diff --git a/events.out.tfevents.1660117225.t1v-n-eedfb410-w-0.55204.0.v2 b/events.out.tfevents.1660117225.t1v-n-eedfb410-w-0.55204.0.v2 new file mode 100644 index 0000000000000000000000000000000000000000..6dc7e6c2d781f2302081d90cb83c1e4ddac6ef43 --- /dev/null +++ b/events.out.tfevents.1660117225.t1v-n-eedfb410-w-0.55204.0.v2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6e2942e6466c560f7eac1111e0c5b04e5ba30bc241b0d7408d895e2e7cad769c +size 40 diff --git a/events.out.tfevents.1660130897.t1v-n-eedfb410-w-0.8420.0.v2 b/events.out.tfevents.1660130897.t1v-n-eedfb410-w-0.8420.0.v2 new file mode 100644 index 0000000000000000000000000000000000000000..dbbccf557b1ea78fccbf363ca38f42ee6ebf74cf --- /dev/null +++ b/events.out.tfevents.1660130897.t1v-n-eedfb410-w-0.8420.0.v2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d3db5592c5e247b36ba97b636878b108457cd33220828a88069711c1ae23838 +size 40 diff --git a/events.out.tfevents.1660143983.t1v-n-eedfb410-w-0.3332902.0.v2 b/events.out.tfevents.1660143983.t1v-n-eedfb410-w-0.3332902.0.v2 new file mode 100644 index 0000000000000000000000000000000000000000..8ca915a861a5ade870237abe1cdd285f47a26291 --- /dev/null +++ b/events.out.tfevents.1660143983.t1v-n-eedfb410-w-0.3332902.0.v2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f299990d7dfa6f5112ae7da90bcf6a5217143514806dc1b21b14594cfc58a389 +size 40 diff --git a/events.out.tfevents.1660145355.t1v-n-eedfb410-w-0.2349240.0.v2 b/events.out.tfevents.1660145355.t1v-n-eedfb410-w-0.2349240.0.v2 new file mode 100644 index 0000000000000000000000000000000000000000..ae3d0e018ea116b43e41ff456aab74add62f37dc --- /dev/null +++ b/events.out.tfevents.1660145355.t1v-n-eedfb410-w-0.2349240.0.v2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:646540f39538eecbf2b7e4d42ca372657d297ba5bd1f4206725e039acbea46a4 +size 40 diff --git a/events.out.tfevents.1660206880.t1v-n-eedfb410-w-0.1479163.0.v2 b/events.out.tfevents.1660206880.t1v-n-eedfb410-w-0.1479163.0.v2 new file mode 100644 index 0000000000000000000000000000000000000000..218da969145353380f18ac00e7e164f1533bc4b7 --- /dev/null +++ b/events.out.tfevents.1660206880.t1v-n-eedfb410-w-0.1479163.0.v2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:26e0c8ec22f6f03deb1c4c2b33e27affb7898c696044d480d3ed8861cbe6ad58 +size 40 diff --git a/events.out.tfevents.1660208728.t1v-n-eedfb410-w-0.503538.0.v2 b/events.out.tfevents.1660208728.t1v-n-eedfb410-w-0.503538.0.v2 new file mode 100644 index 0000000000000000000000000000000000000000..919ed04c2e10dda91948378866a4b9a517646774 --- /dev/null +++ b/events.out.tfevents.1660208728.t1v-n-eedfb410-w-0.503538.0.v2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:09793df210508f083a2bc11c7343fd21bdd95cab1d58deef35d0921cd281ffa8 +size 40 diff --git a/events.out.tfevents.1660218137.t1v-n-eedfb410-w-0.2916397.0.v2 b/events.out.tfevents.1660218137.t1v-n-eedfb410-w-0.2916397.0.v2 new file mode 100644 index 0000000000000000000000000000000000000000..ee9ce7652cfff383009190d5c9d39667484f4d1d --- /dev/null +++ b/events.out.tfevents.1660218137.t1v-n-eedfb410-w-0.2916397.0.v2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:95289ecfa3686050d8cfdb3c04ca3ef3ff1d6fe0c12a59ee107d52e0ebdb0d29 +size 40 diff --git a/flax_model.msgpack b/flax_model.msgpack index f16ff5d500887fa1549f2a8b169a89a7fdede0d8..679dd9c4373e2413b7990cfcd98d4399c31f0bca 100644 --- a/flax_model.msgpack +++ b/flax_model.msgpack @@ -1,3 +1,3 
@@ version https://git-lfs.github.com/spec/v1 -oid sha256:d6e7fe76ddde6be27c0129735dc3bb50191fff23d8f87075f1335970abf06211 +oid sha256:2824056b9c2f157ff862c17877b5aa4a77f0f6107345973495c02df3828b7469 size 3850218852 diff --git a/run.recover.sh b/run.recover.sh index 77ad3fdef2ecb8420c6e0595876044a70f63b930..632a3366c5795aa6a8ba0fbf3e039cbab3f7f037 100755 --- a/run.recover.sh +++ b/run.recover.sh @@ -11,9 +11,8 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c --per_device_train_batch_size="2" \ --per_device_eval_batch_size="2" \ --gradient_accumulation_steps="1" \ - --precision="full_mixed" \ + --precision="half_mixed" \ --matmul_precision="bfloat16" \ - --multisteps \ --learning_rate="6.394633237505332e-05" \ --skip_steps="275000" \ --warmup_steps="2000" \ diff --git a/run.sh b/run.sh index 875897821c3c7dbe343de036c763893d828f771d..6adf9ee68382637f5cf5e196fcb0bf81f0df2bb3 100755 --- a/run.sh +++ b/run.sh @@ -1,3 +1,6 @@ +# See https://github.com/sanchit-gandhi/seq2seq-speech/issues/23#issuecomment-1122183173: do_lower_case should only be set to True for the tokenizer if the tokenizer has upper case letters in the vocab +# Let's also not add extra remove_punctuation +# And limit max duration to 25 seconds WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_ctc.py \ --model_name_or_path="facebook/wav2vec2-xls-r-1b" \ --hub_model_id="NbAiLab/wav2vec2-1b-npsc-nst-tpu" \ @@ -11,7 +14,7 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c --precision="full_mixed" \ --matmul_precision="bfloat16" \ --multisteps \ - --learning_rate="1e-4" \ + --learning_rate="2e-5" \ --warmup_steps="2000" \ --length_column_name="input_length" \ --evaluation_strategy="steps" \ @@ -32,7 +35,7 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c --mask_feature_length="64" \ --gradient_checkpointing \ --min_duration_in_seconds="0.5" \ - --max_duration_in_seconds="30.0" \ + --max_duration_in_seconds="25.0" \ --use_auth_token \ --seed="42" \ --group_by_length \ @@ -40,10 +43,5 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c --push_to_hub \ --preprocessing_num_workers="32" \ --ctc_zero_infinity \ - --do_lower_case \ --wandb_project="wav2vec2" \ --wandb_name="wav2vec2-1b-npsc-nst-tpu" \ - --remove_punctuation - - -# --fp16 diff --git a/run_flax_speech_recognition_ctc.py b/run_flax_speech_recognition_ctc.py index a3308794e399464a6649e4f73284f7a0586a92ce..4a5d9404604defc9d7cdab04832844ec56ce7978 100644 --- a/run_flax_speech_recognition_ctc.py +++ b/run_flax_speech_recognition_ctc.py @@ -415,12 +415,12 @@ class MixedPrecisionTrainState(struct.PyTreeNode): ) @classmethod - def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs): + def create(cls, *, apply_fn, params, tx, to_dtype, step=0, **kwargs): """Creates a new instance with `step=0` and initialized `opt_state`.""" # downcast optimizer state to bf16 if mixed-precision training opt_state = tx.init(to_dtype(params)) if tx is not None else None return cls( - step=0, + step=step, apply_fn=apply_fn, params=params, tx=tx, @@ -1339,6 +1339,7 @@ def main(): # Setup train state state = MixedPrecisionTrainState.create( + step=data_args.skip_steps, apply_fn=model.__call__, get_attention_mask_fn=model._get_feature_vector_attention_mask, params=model.params, @@ -1517,14 +1518,13 @@ def main(): if training_args.do_train: # ======================== Training ================================ train_start = time.time() + # 
Create sampling rng + rng, input_rng = jax.random.split(rng) if epoch < skip_epochs: logger.info(f"Skipping epoch {epoch + 1}") continue - # Create sampling rng - rng, input_rng = jax.random.split(rng) - # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng) train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update) diff --git a/special_tokens_map.json b/special_tokens_map.json index 218961f90a6177a4ba0a6afe44ea30089da575ca..ef2d6eddbcdb9e27f2632915c04d2f11abe81542 100644 --- a/special_tokens_map.json +++ b/special_tokens_map.json @@ -399,6 +399,104 @@ "rstrip": false, "single_word": false }, + { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, { "content": "", "lstrip": false, diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log index 23926ef77b11a5353a1de250a3470863e004a83a..737c280688327345d0bc0adb1feed16f758dd2d8 120000 --- a/wandb/debug-internal.log +++ b/wandb/debug-internal.log @@ -1 +1 @@ -run-20220805_230151-2y71vcu4/logs/debug-internal.log \ No newline at end of file +run-20220811_101752-mzjvp6ho/logs/debug-internal.log \ No newline at end of file diff --git a/wandb/debug.log b/wandb/debug.log index 279853d4bafde426620c831150a806e4faba7184..fcc4539ff7612a0c410affe657823638cdabe104 120000 --- a/wandb/debug.log +++ b/wandb/debug.log @@ -1 +1 @@ -run-20220805_230151-2y71vcu4/logs/debug.log \ No newline at end of file +run-20220811_101752-mzjvp6ho/logs/debug.log \ No newline at end of file diff --git a/wandb/latest-run b/wandb/latest-run index f069a7ae28491f2d6f3aa9f800a46a1881be948d..e1726b485d34f15153548233ee89da31bbe2f650 120000 --- a/wandb/latest-run +++ b/wandb/latest-run @@ -1 +1 @@ -run-20220805_230151-2y71vcu4 \ No newline at end of file +run-20220811_101752-mzjvp6ho \ No newline at end of file diff --git a/wandb/run-20220810_073735-23avj35z/files/code/run_flax_speech_recognition_ctc.py 
b/wandb/run-20220810_073735-23avj35z/files/code/run_flax_speech_recognition_ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..a3308794e399464a6649e4f73284f7a0586a92ce --- /dev/null +++ b/wandb/run-20220810_073735-23avj35z/files/code/run_flax_speech_recognition_ctc.py @@ -0,0 +1,1631 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2022 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition. +""" +# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. + +import logging +import math +import os +import re +import sys +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + +import datasets +import numpy as np +from datasets import DatasetDict, load_dataset, load_metric +from tqdm import tqdm + +import flax +import jax +import jax.numpy as jnp +import optax +import transformers +import wandb as wandb +from flax import core, jax_utils, struct, traverse_util +from flax.jax_utils import unreplicate, pad_shard_unpad +from flax.training.common_utils import get_metrics, shard, shard_prng_key +from huggingface_hub import Repository +from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC +from optax._src import linear_algebra +from transformers import ( + AutoFeatureExtractor, + AutoProcessor, + AutoTokenizer, + HfArgumentParser, + TrainingArguments, + is_tensorboard_available, + set_seed, +) +from transformers.file_utils import get_full_repo_name +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.17.0.dev0") + +require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") + +logger = logging.getLogger(__name__) + + +@flax.struct.dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
+ """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + feature_extractor_name: Optional[str] = field( + default=None, metadata={"help": "feature extractor name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + freeze_feature_encoder: bool = field( + default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."} + ) + attention_dropout: float = field( + default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."} + ) + activation_dropout: float = field( + default=0.1, + metadata={ + "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler." + }, + ) + hidden_dropout: float = field( + default=0.1, + metadata={ + "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler." + }, + ) + feat_proj_dropout: float = field( + default=0.0, + metadata={ + "help": "The feat proj dropout probability for feature encoder representations." + }, + ) + final_dropout: float = field( + default=0.0, + metadata={"help": "The dropout probability for the final projection layer."}, + ) + mask_time_prob: float = field( + default=0.1, + metadata={ + "help": "The spec aug dropout probability for feature encoder representations." + }, + ) + mask_time_length: int = field( + default=10, + metadata={"help": "Length of vector span to mask along the time axis."}, + ) + mask_feature_prob: float = field( + default=0.0, + metadata={ + "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector" + "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis." + }, + ) + mask_feature_length: int = field( + default=10, + metadata={"help": "Length of vector span to mask along the feature axis."}, + ) + layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."}) + ctc_loss_reduction: Optional[str] = field( + default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."} + ) + ctc_zero_infinity: Optional[bool] = field( + default=False, metadata={"help": "If True, will try yo aboud the CTC loss goinf to infinity."} + ) + + +@flax.struct.dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. 
+ """ + + dataset_name: str = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + text_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, + ) + dataset_cache_dir: Optional[str] = field( + default=None, metadata={"help": "Path to cache directory for saving and loading datasets"} + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_test_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of test examples to this " + "value if set." + }, + ) + audio_column_name: str = field( + default="audio", + metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"}, + ) + text_column_name: str = field( + default="text", + metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"}, + ) + max_duration_in_seconds: float = field( + default=20.0, + metadata={ + "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`" + }, + ) + min_duration_in_seconds: float = field( + default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"} + ) + max_label_length: Optional[int] = field( + default=512, + metadata={ + "help": "The minimum total sequence length for target text after tokenization. Sequences shorter " + "than this will be filtered." + }, + ) + min_label_length: Optional[int] = field( + default=2, + metadata={ + "help": "The minimum total sequence length for target text after tokenization. Sequences shorter " + "than this will be filtered." + }, + ) + pad_input_to_multiple_of: Optional[int] = field( + default=32000, + metadata={ + "help": "If set will pad the input sequence to a multiple of the provided value. " + "This is important to avoid triggering recompilations on TPU." + }, + ) + pad_target_to_multiple_of: Optional[int] = field( + default=None, + metadata={ + "help": "If set will pad the target sequence to a multiple of the provided value. " + "This is important to avoid triggering recompilations on TPU." + }, + ) + preprocessing_only: bool = field( + default=False, + metadata={ + "help": "Whether to only do data preprocessing and skip training. " + "This is especially useful when data preprocessing errors out in distributed training due to timeout. 
" + "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` " + "so that the cached datasets can consequently be loaded in distributed training" + }, + ) + train_split_name: str = field( + default="train", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + eval_split_name: str = field( + default="validation", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + do_lower_case: bool = field( + default=True, + metadata={"help": "Whether the target text should be lower cased."}, + ) + wandb_project: str = field( + default="flax-speech-recognition-ctc", + metadata={"help": "The name of the wandb project."}, + ) + wandb_name: str = field( + default=None, + metadata={"help": "The name of the wandb run."}, + ) + wandb_job_type: str = field( + default="CTC", + metadata={"help": "The name of the wandb job type."}, + ) + test_split_name: str = field( + default="test", + metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"}, + ) + remove_punctuation: bool = field( + default=False, metadata={"help": "Whether or not to remove punctuation during training."} + ) + skip_steps: Optional[int] = field( + default=0, + metadata={ + "help": "Skip this number of steps. Useful to continue training" + }, + ) + + +# @flax.struct.dataclass +@dataclass +class FlaxTrainingArguments(TrainingArguments): + precision: str = field( + default="full", + metadata={ + "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision" + "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**" + }, + ) + matmul_precision: str = field( + default="default", + metadata={ + "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. " + "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). " + "This configuration option does not change the behaviours of such calls with explicit precision arguments; " + "it only changes the behaviors of calls with no such argument provided. " + "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`." + }, + ) + multisteps: bool = field( + default=False, + metadata={ + "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, " + "a custom gradient accumulation implementation will be employed." + }, + ) + + +def to_fp32(t): + return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t) + + +def to_bf16(t): + return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t) + + +class MixedPrecisionTrainState(struct.PyTreeNode): + """Train state for use with a single Optax optimizer. 
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py + + Synopsis:: + + state = TrainState.create( + apply_fn=model.apply, + params=variables['params'], + tx=tx) + grad_fn = jax.grad(make_loss_fn(state.apply_fn)) + for batch in data: + grads = grad_fn(state.params, batch) + state = state.apply_gradients(grads=grads) + + Args: + step: Counter starts at 0 and is incremented by every call to + `.apply_gradients()`. + apply_fn: Usually set to `model.apply()`. Kept in this dataclass for + convenience to have a shorter params list for the `train_step()` function + in your training loop. + params: The parameters to be updated by `tx` and used by `apply_fn`. + tx: An Optax gradient transformation. + opt_state: The state for `tx`. + dropout_rng: PRNG key for stochastic operations. + bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. + """ + + step: int + apply_fn: Callable = struct.field(pytree_node=False) + get_attention_mask_fn: Callable = struct.field(pytree_node=False) + params: core.FrozenDict[str, Any] + tx: optax.GradientTransformation = struct.field(pytree_node=False) + opt_state: optax.OptState + dropout_rng: jnp.ndarray + max_grad_norm: Optional[float] = 1.0 + + def apply_gradients(self, *, grads, to_dtype, **kwargs): + """Updates `step`, `params`, `opt_state` and `**kwargs` in return value. + + Note that internally this function calls `.tx.update()` followed by a call + to `optax.apply_updates()` to update `params` and `opt_state`. + + Args: + grads: Gradients that have the same pytree structure as `.params`. + **kwargs: Additional dataclass attributes that should be `.replace()`-ed. + + Returns: + An updated instance of `self` with `step` incremented by one, `params` + and `opt_state` updated by applying `grads`, and additional attributes + replaced as specified by `kwargs`. + """ + + # clip gradients by global l2 norm + casted_max_grad_norm = to_dtype(self.max_grad_norm) + g_norm = linear_algebra.global_norm(grads) + g_norm = jnp.maximum(casted_max_grad_norm, g_norm) + grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads) + + # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training + # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is) + updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params) + + new_params = optax.apply_updates(self.params, updates) + return self.replace( + step=self.step + 1, + params=new_params, + opt_state=to_dtype(new_opt_state), + **kwargs, + ) + + @classmethod + def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs): + """Creates a new instance with `step=0` and initialized `opt_state`.""" + # downcast optimizer state to bf16 if mixed-precision training + opt_state = tx.init(to_dtype(params)) if tx is not None else None + return cls( + step=0, + apply_fn=apply_fn, + params=params, + tx=tx, + opt_state=opt_state, + **kwargs, + ) + + def replicate(self): + return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) + + +@flax.struct.dataclass +class FlaxDataCollatorSpeechSeq2SeqWithPadding: + """ + Data collator that will dynamically pad the inputs received. + Args: + processor ([`Wav2Vec2Processor`]) + The processor used for proccessing the data. + decoder_start_token_id (:obj: `int`) + The begin-of-sentence of the decoder. 
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned input sequences (according to the model's padding side and padding index) + among: + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned target sequences (according to the model's padding side and padding index). + See above for details. + max_input_length (:obj:`float`, `optional`): + Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). + pad_input_to_multiple_of (:obj:`int`, `optional`): + If set will pad the input sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + pad_target_to_multiple_of (:obj:`int`, `optional`): + If set will pad the target sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + """ + + processor: Any + input_padding: Union[bool, str] = "longest" + label_padding: Union[bool, str] = "max_length" + pad_input_to_multiple_of: Optional[int] = None + pad_to_multiple_of_label: Optional[int] = None + max_input_length: Optional[float] = None + max_label_length: Optional[float] = None + + def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]: + # split inputs and labels since they have to be of different lengths and need + # different padding methods + input_features = [{"input_values": feature["input_values"]} for feature in features] + label_features = [{"input_ids": feature["labels"]} for feature in features] + + # reformat list to dict and set to pytorch format + batch = self.processor.feature_extractor.pad( + input_features, + max_length=self.max_input_length, + padding=self.input_padding, + pad_to_multiple_of=self.pad_input_to_multiple_of, + return_tensors="np", + ) + + labels_batch = self.processor.tokenizer.pad( + label_features, + max_length=self.max_label_length, + padding=self.label_padding, + pad_to_multiple_of=self.pad_to_multiple_of_label, + return_tensors="np", + ) + + labels = labels_batch["input_ids"] + labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1)) + labels = labels.filled(fill_value=-100) + + batch["labels"] = labels + + return batch + + +def get_grouped_indices( + dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None +) -> np.array: + """ + Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486) + Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar + lengths. 
To do this, the indices are: + + - randomly permuted (if a JAX rng is specified) + - grouped in mega-batches of size `mega_batch_mult * batch_size` + - sorted by length in each mega-batch + + The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of + maximum length placed first, so that an OOM happens sooner rather than later. + """ + lengths = dataset["input_length"] + + # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller. + if mega_batch_mult is None: + mega_batch_mult = min(len(lengths) // (batch_size * 4), 50) + # Just in case, for tiny datasets + if mega_batch_mult == 0: + mega_batch_mult = 1 + + # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler. + num_samples = len(lengths) + indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples) + + megabatch_size = mega_batch_mult * batch_size + megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] + megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches] + + # The rest is to get the biggest batch first. + # Since each megabatch is sorted by descending length, the longest element is the first + megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches] + max_idx = np.argmax(megabatch_maximums).item() + # Switch to put the longest batch in first position + # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch) + megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0] + + megabatches = np.array([i for megabatch in megabatches for i in megabatch]) + + return megabatches + + +def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray: + """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by + the batch size and `drop_last` is `True`, the last incomplete batch is dropped. 
Else, it is returned.""" + num_samples = len(samples_idx) + if drop_last: + samples_to_remove = num_samples % batch_size + if samples_to_remove != 0: + samples_idx = samples_idx[:-samples_to_remove] + sections_split = num_samples // batch_size + samples_idx = samples_idx.reshape((sections_split, batch_size)) + else: + sections_split = math.ceil(num_samples / batch_size) + samples_idx = np.array_split(samples_idx, sections_split) + return samples_idx + + +def write_train_metric(summary_writer, train_metrics, train_time, step): + summary_writer.scalar("train_time", train_time, step) + + train_metrics = get_metrics(train_metrics) + for key, vals in train_metrics.items(): + tag = f"train_{key}" + for i, val in enumerate(vals): + summary_writer.scalar(tag, val, step - len(vals) + i + 1) + + +def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None): + for metric_name, value in eval_metrics.items(): + summary_writer.scalar(f"eval_{metric_name}", value, step) + + if pred_str is not None: + # write output actual predictions for debugging + summary_writer.text("eval_predictions", "\n".join(pred_str), step) + + +def write_wandb_log(metrics, step, prefix=None): + if jax.process_index() == 0: + log_metrics = {} + for k, v in metrics.items(): + if "layer" in k: + log_metrics[f"{k}/"] = v + elif prefix is not None: + log_metrics[f"{prefix}/{k}"] = v + else: + log_metrics[k] = v + wandb.log(log_metrics, step) + + +def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"): + if jax.process_index() == 0: + # convert str data to a wandb compatible format + str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))] + # we'll log the first 50 predictions for each epoch + wandb.log( + { + f"{prefix}/step_{int(step / 1000)}k": wandb.Table( + columns=["label_str", "pred_str"], data=str_data[:num_log] + ) + }, + step, + ) + + +def create_learning_rate_fn( + num_train_steps: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.array]: + """Returns a linear warmup, linear_decay learning rate function.""" + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +def ctc_loss( + logits, + logits_attention_mask, + labels, + blank_id, + loss_reduction="mean", + output_emission_dict=False, + log_epsilon=-100000.0, +): + """Computes CTC loss. + This function performs forward computation over an FSA with `N * 2` states + where `N` is the max number of labels. The states are split into two groups: + Phi states and emission states. a phi-state accepts repetition of + phi (blank)-symbols and transits to emission state when the correct label is + observed. An emission state accepts repetition of the label and transits to + the next phi states at any time (so called epsilon-transition). + Below, `B` denotes the batch size, `T` denotes the time steps in `logits`, + and `N` denotes the time steps in `labels`. + Args: + logits: (B, T, K)-array containing log-probabilities of each class. + logitpaddings: (B, T)-array. Padding indicators for `logits`. + labels: (B, N)-array containing reference integer labels. + labelpaddings: (B, N)-array. Padding indicators for `labels`. Currently, + `labels` must be right-padded, i.e. 
each row of `labelpaddings` must be + repetition of zeroes, followed by repetition of ones. + blank_id: Id for blank token. + loss_reduction: one of "mean", "sum", "default" + - "none": no reduction is applied. + - "mean": output loss will be divided by target lengths and then the + mean over the batch is taken. + - "sum": output loss are summed over batch + output_emission_dict: whether to output additional information about the emission probs + Returns: + A pair of `(per_seq_loss, aux)`. + per_seq_loss: + (B,)-array containing loss values for each sequence in the batch. + aux: Dictionary containing interim variables used for computing losses. + aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each + phi-state corresponding to the n-th label. + aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each + emission-state corresponding to the n-th label. + aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol + corresponding to each time frame. + aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label + corresponding to each time frame. + """ + # label paddings are indicated by -100 + labelpaddings = labels < 0 + # logit paddings are the inverse of attention_mask + logitpaddings = ~logits_attention_mask + + # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py + batchsize, unused_maxinputlen, num_classes = logits.shape + batchsize_, maxlabellen = labels.shape + + logprobs = jax.nn.log_softmax(logits) + labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32) + + # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1]. + repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32) + repeat = jnp.pad(repeat, ((0, 0), (0, 1))) + + logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1] + logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] + + one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K] + logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot) + logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] + + logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N] + logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0) + logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N] + + def loop_body(prev, x): + prev_phi, prev_emit = prev + # emit-to-phi epsilon transition, except if the next label is repetition + prev_phi_orig = prev_phi + prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat)) + + logprob_emit, logprob_phi, pad = x + + # phi-to-emit transition + next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit) + # self-loop transition + next_phi = prev_phi + logprob_phi + # emit-to-phi blank transition only when the next label is repetition + next_phi = next_phi.at[:, 1:].set( + jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)) + ) + + pad = pad.reshape((batchsize, 1)) + next_emit = pad * prev_emit + (1.0 - pad) * next_emit + next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi + + return (next_phi, next_emit), (next_phi, next_emit) + + xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0))) + _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs) + + # last row needs to be updated with the last epsilon transition + logalpha_phi_last = logalpha_phi[-1].at[:, 
1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1])) + logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last) + + # extract per_seq_loss + one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1] + per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot) + + if loss_reduction == "mean": + target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1) + loss = (per_seq_loss / target_lengths).mean() + elif loss_reduction == "sum": + loss = per_seq_loss.sum() + else: + loss = per_seq_loss + + if not output_emission_dict: + return loss + + return loss, { + "logalpha_phi": logalpha_phi, + "logalpha_emit": logalpha_emit, + "logprobs_phi": logprobs_phi, + "logprobs_emit": logprobs_emit, + } + + +def make_dataset(data_args, seed=42): + # Pre-processing dataset + import re + + def map_nst(entry): + text = entry["text"].lower() + text = text.replace("(...vær stille under dette opptaket...)", "") + text = re.sub('[áàâ]', 'a', text) + text = re.sub('[ä]', 'æ', text) + text = re.sub('[éèëê]', 'e', text) + text = re.sub('[íìïî]', 'i', text) + text = re.sub('[óòöô]', 'o', text) + text = re.sub('[ö]', 'ø', text) + text = re.sub('[ç]', 'c', text) + text = re.sub('[úùüû]', 'u', text) + # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text) + text = re.sub('\s+', ' ', text) + return {"text": text} + + def filter_nst(entry): + if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)): + return False # Too short + if re.match(entry["type"], "pIW|CA"): + return False # Spelling out words + return True + + def filter_npsc(entry): + # False if there are digits in the text + if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)): + return False # Too short + if re.search("\d", entry["text"]): + return False + return True + + def map_npsc(entry): + batch = {"text": entry["text"].lower()} + batch["text"] = re.sub('[áàâ]', 'a', batch["text"]) + batch["text"] = re.sub('[ä]', 'æ', batch["text"]) + batch["text"] = re.sub('[éèëê]', 'e', batch["text"]) + batch["text"] = re.sub('[íìïî]', 'i', batch["text"]) + batch["text"] = re.sub('[óòöô]', 'o', batch["text"]) + batch["text"] = re.sub('[ö]', 'ø', batch["text"]) + batch["text"] = re.sub('[ç]', 'c', batch["text"]) + batch["text"] = re.sub('[úùüû]', 'u', batch["text"]) + batch["text"] = re.sub('\s', ' ', batch["text"]) + batch["text"] = re.sub('', 'eee', batch["text"]) + batch["text"] = re.sub('', 'qqq', batch["text"]) + batch["text"] = re.sub('', 'mmm', batch["text"]) + batch["text"] = re.sub('', 'xxx', batch["text"]) + # batch["text"] = re.sub('', '?', batch["text"]) + if "<" in batch["text"]: + raise ValueError(batch["text"]) + return batch + + nst = datasets.load_dataset("NbAiLab/NST", "no-close") + npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3") + # TODO NST_hesitate + + split = len(npsc[data_args.train_split_name]) / (len(npsc[data_args.train_split_name]) + len(npsc[data_args.eval_split_name])) # Use same train/val ratio as NPSC + nst_train = nst[data_args.train_split_name].train_test_split(train_size=split, seed=seed) + nst[data_args.train_split_name] = nst_train["train"] + nst[data_args.eval_split_name] = nst_train["test"] + + nst = nst.filter(filter_nst).map( + map_nst, + num_proc=data_args.preprocessing_num_workers, + desc="filtering NST", + ).shuffle(seed=seed) + npsc = npsc.filter(filter_npsc).map( + map_npsc, + num_proc=data_args.preprocessing_num_workers, + desc="filtering NPSC", 
+ ).shuffle(seed=seed) + + npsc_base = npsc.remove_columns([col for col in npsc[data_args.train_split_name].column_names if col not in ["text", "audio"]]) + nst_base = nst.remove_columns([col for col in nst[data_args.train_split_name].column_names if col not in ["text", "audio"]]) + + combined = {} + for split in data_args.train_split_name, data_args.eval_split_name, data_args.test_split_name: + probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples + probs = (probs / probs.sum()).tolist() + comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed) + combined[split] = comb + + return datasets.DatasetDict(**combined) + +def main(): + # 1. Parse input arguments + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # 2. Setup logging + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + # Set the verbosity to info of the Transformers logger. + # We only want one process per machine to log things on the screen. + logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) + if jax.process_index() == 0: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # Set up wandb run + if jax.process_index() == 0: + wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type) + + logger.info("Training/evaluation parameters %s", training_args) + + # Set the default TPU matmul precision and display the number of devices + jax.config.update("jax_default_matmul_precision", training_args.matmul_precision) + logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}") + + # 4. 
Load dataset + + set_seed(training_args.seed) + raw_datasets = make_dataset(data_args, seed=training_args.seed) + + # raw_datasets = DatasetDict() + + # if training_args.do_train: + # raw_datasets[data_args.train_split_name] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=data_args.train_split_name, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + # if training_args.do_eval: + # raw_datasets[data_args.eval_split_name] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=data_args.eval_split_name, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + # if training_args.do_predict: + # test_split = data_args.test_split_name.split("+") + # for split in test_split: + # raw_datasets[split] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=split, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + if not training_args.do_train and not training_args.do_eval and not training_args.do_predict: + raise ValueError( + "Cannot not train, not do evaluation and not do prediction. At least one of " + "training, evaluation or prediction has to be done." + ) + + # if not training, there is no need to run multiple epochs + if not training_args.do_train: + training_args.num_train_epochs = 1 + + if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--audio_column_name` to the correct audio column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--text_column_name` to the correct text column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + # 5. 
Load pretrained model, tokenizer, and feature extractor + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + config = Wav2Vec2Config.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + # update config according to training args, model args, and tokenizer attributes + config.update( + { + "feat_proj_dropout": model_args.feat_proj_dropout, + "attention_dropout": model_args.attention_dropout, + "hidden_dropout": model_args.hidden_dropout, + "final_dropout": model_args.final_dropout, + "mask_time_prob": model_args.mask_time_prob, + "mask_time_length": model_args.mask_time_length, + "mask_feature_prob": model_args.mask_feature_prob, + "mask_feature_length": model_args.mask_feature_length, + "gradient_checkpointing": training_args.gradient_checkpointing, + "layerdrop": model_args.layerdrop, + "ctc_loss_reduction": model_args.ctc_loss_reduction, + "ctc_zero_infinity": model_args.ctc_zero_infinity, + "pad_token_id": tokenizer.pad_token_id, + "vocab_size": tokenizer.vocab_size, # len(tokenizer), + "activation_dropout": model_args.activation_dropout, + } + ) + + if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr": + raise ValueError( + "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to " + "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus," + "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is mostly likely " + "built on an lowercased corpus. In this case, set `tokenizer.do_lower_case` to ``False`." + ) + + if training_args.precision == "full_mixed": + dtype = jnp.bfloat16 + training_args.mixed_precision = True + elif training_args.precision == "half_mixed": + dtype = jnp.bfloat16 + training_args.mixed_precision = False + else: + dtype = jnp.float32 + training_args.mixed_precision = False + + try: + model = FlaxWav2Vec2ForCTC.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + except: + model = FlaxWav2Vec2ForCTC.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + from_pt=True, + ) + + # 6. Resample speech dataset ALWAYS + raw_datasets = raw_datasets.cast_column( + data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) + ) + + # 7. Preprocessing the datasets. 
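# --- Editor's note: illustrative arithmetic, not part of the committed script ---
# Rough reasoning behind the run.sh change from --max_duration_in_seconds="30.0" to
# "25.0": assuming 16 kHz audio from the XLS-R feature extractor and the default
# pad_input_to_multiple_of=32000 (i.e. 2-second padding buckets), the duration cap
# bounds both per-example memory and the number of distinct padded input shapes,
# and hence the number of XLA recompilations on TPU.
import math
sampling_rate = 16_000                      # Hz, assumed feature_extractor.sampling_rate
max_duration_s = 25.0                       # value set in run.sh above
pad_multiple = 32_000                       # default pad_input_to_multiple_of (2 s)
max_input_length = int(max_duration_s * sampling_rate)           # 400000 samples
num_padded_shapes = math.ceil(max_input_length / pad_multiple)   # at most 13 padded shapes
# --- end editor's note ---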
+ # We need to read the audio files as arrays and tokenize the targets. + max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate) + min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate) + max_target_length = data_args.max_label_length + min_target_length = data_args.min_label_length + pad_input_to_multiple_of = data_args.pad_input_to_multiple_of + audio_column_name = data_args.audio_column_name + num_workers = data_args.preprocessing_num_workers + text_column_name = data_args.text_column_name + model_input_name = feature_extractor.model_input_names[0] + do_lower_case = data_args.do_lower_case + dataset_name = data_args.dataset_name + chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ") + chars_to_ignore_regex = f'[{"".join(chars_to_ignore)}]' + # gigaspeech_punctuation = {" ": ",", " ": ".", " ": "?", " ": "!"} + # gigaspeech_disfluencies = ["", ""] + # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "", "", "", "[laughter-", + # "[vocalized-noise]", "_1"] + # swb_punctuations = ["{", "}", "[", "]-", "]"] + # earnings_disfluencies = ["", "", "", "inaudible", "", ""] + ignore_segments = ["ignore_time_segment_in_scoring", "", "", "[noise]", "[laughter]", "[silence]", + "[vocalized-noise]", "", "", "", "", "", "", ""] + + if training_args.do_train and data_args.max_train_samples is not None: + raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].select(range(data_args.max_train_samples)) + + if training_args.do_eval and data_args.max_eval_samples is not None: + raw_datasets[data_args.eval_split_name] = raw_datasets[data_args.eval_split_name].select(range(data_args.max_eval_samples)) + + if training_args.do_predict and data_args.max_test_samples is not None: + raw_datasets[data_args.test_split_name] = raw_datasets[data_args.test_split_name].select(range(data_args.max_eval_samples)) + + if training_args.do_train and data_args.remove_punctuation: + + def remove_punctuation(batch): + batch[text_column_name] = ( + re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "") + ) + + raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].map( + remove_punctuation, + num_proc=data_args.preprocessing_num_workers, + desc="removing punctuation from train split", + ) + + # filter data where the targets are ignored in scoring + def is_target_labels(input_str): + return input_str.lower() not in ignore_segments + + raw_datasets = raw_datasets.filter( + is_target_labels, + num_proc=num_workers, + input_columns=[text_column_name], + desc="filtering data where the targets are ignored in scoring", + ) + + def prepare_dataset(batch): + # process audio + try: + sample = batch[audio_column_name] + except ValueError: + sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate} + inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"]) + # process audio length + batch[model_input_name] = inputs.input_values[0] + batch["input_length"] = len(batch["input_values"]) + + # process targets + input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name] + + # if dataset_name == "google/xtreme_s": + # # Finally, we tokenize the processed text + # batch["labels"] = tokenizer(input_str).input_ids + # batch["labels_length"] = len(batch["labels"]) + # return batch + + # # Common Voice 9 + # if input_str.startswith('"') and input_str.endswith('"'): + # # we can remove 
trailing quotation marks as they do not affect the transcription + # input_str = input_str[1:-1] + # # normalize quotation marks + # input_str = re.sub(r'["“”]', '"', input_str) + # # normalize apostrophes + # input_str = re.sub(r"[’']", "'", input_str) + # # normalize hyphens + # input_str = re.sub(r"[—–]", "-", input_str) + # # replace double quotation marks with single + # input_str = input_str.replace('""', '"') + # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str): + # # for CV9, we'll normalize the text to always finish with punctuation + # if input_str[-1] not in [".", "?", "!"]: + # input_str = input_str + "." + + # # TEDLIUM-3 + # # delete the token from the text and replace spaced apostrophes with un-spaced + # input_str = input_str.replace("", "").replace(" '", "'") + + # # GigaSpeech + # for disfluency in gigaspeech_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # convert spelled out punctuation to symbolic form + # for punctuation, replacement in gigaspeech_punctuation.items(): + # input_str = input_str.replace(punctuation, replacement) + # if dataset_name == "speechcolab/gigaspeech" and len(input_str): + # # for GS, we'll normalize the text to always finish with punctuation + # if input_str[-1] not in [".", "?", "!"]: + # input_str = input_str + "." + + # # SWB + # for disfluency in swb_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # remove parenthesised text (test data only) + # input_str = re.sub("[\(].*?[\)]", "", input_str) + # for punctuation in swb_punctuations: + # input_str = input_str.replace(punctuation, "") + # # replace anomalous words with their correct transcriptions + # split_str = input_str.split("/") + # if len(split_str) > 1: + # input_str = " ".join( + # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]]) + + # # Earnings 22 + # for disfluency in earnings_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # replace mal-formatted ellipsis + # input_str = input_str.replace("…", ".") + + # JIWER compliance + # remove multiple spaces + input_str = re.sub(r"\s\s+", " ", input_str) + # strip trailing spaces + input_str = input_str.strip() + + # Finally, we tokenize the processed text + batch["labels"] = tokenizer(input_str).input_ids + batch["labels_length"] = len(batch["labels"]) + return batch + + vectorized_datasets = raw_datasets.map( + prepare_dataset, + remove_columns=next(iter(raw_datasets.values())).column_names, + num_proc=num_workers, + desc="preprocess dataset", + ) + + # filter data with inputs shorter than min_input_length or longer than max_input_length + def is_audio_in_length_range(length): + return length > min_input_length and length < max_input_length + + vectorized_datasets = vectorized_datasets.filter( + is_audio_in_length_range, + num_proc=num_workers, + input_columns=["input_length"], + ) + + # filter data with targets shorter than min_target_length or longer than max_target_length + def is_labels_in_length_range(length): + return length > min_target_length # and length < max_target_length + + vectorized_datasets = vectorized_datasets.filter( + is_labels_in_length_range, + num_proc=num_workers, + input_columns=["labels_length"], + ) + + # for large datasets it is advised to run the preprocessing on a + # single machine first with `args.preprocessing_only` since there will mostly likely + # be a timeout when running the script in distributed mode. 
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the + # cached dataset + if data_args.preprocessing_only: + cache = {k: v.cache_files for k, v in vectorized_datasets.items()} + logger.info(f"Data preprocessing finished. Files cached at {cache}.") + return + + # 8. Load Metrics + wer_metric = load_metric("wer") + cer_metric = load_metric("cer") + + def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]): + padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids)) + + pred_str = tokenizer.batch_decode(pred_ids) + # we do not want to group tokens when computing the metrics + label_str = tokenizer.batch_decode(padded_ids, group_tokens=False) + + wer = wer_metric.compute(predictions=pred_str, references=label_str) + cer = cer_metric.compute(predictions=pred_str, references=label_str) + + return {"wer": wer, "cer": cer}, pred_str, label_str + + # 9. save feature extractor, tokenizer and config + feature_extractor.save_pretrained(training_args.output_dir) + tokenizer.save_pretrained(training_args.output_dir) + config.save_pretrained(training_args.output_dir) + + processor = AutoProcessor.from_pretrained(training_args.output_dir) + + data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding( + processor=processor, + input_padding="longest", + pad_input_to_multiple_of=pad_input_to_multiple_of, + max_label_length=data_args.max_label_length, + ) + + # Enable tensorboard only on the master node + has_tensorboard = is_tensorboard_available() + if has_tensorboard and jax.process_index() == 0: + try: + from flax.metrics.tensorboard import SummaryWriter + + summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) + except ImportError as ie: + has_tensorboard = False + logger.warning( + f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" + ) + else: + logger.warning( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run `pip install tensorboard` to enable." + ) + + # 10. Handle the repository creation + if training_args.push_to_hub: + with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f: + git_lfs_extensions = f.read() + if "*.wandb" not in git_lfs_extensions: + f.write("*.wandb filter=lfs diff=lfs merge=lfs -text") + if training_args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) + else: + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) + + # 11. 
Initialize our training + rng = jax.random.PRNGKey(training_args.seed) + rng, dropout_rng = jax.random.split(rng) + + # Store some constants + max_steps = int(training_args.max_steps) + gradient_accumulation_steps = int(training_args.gradient_accumulation_steps) + train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() + batch_size_per_update = train_batch_size * gradient_accumulation_steps + per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) + eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() + to_dtype = to_bf16 if training_args.mixed_precision else to_fp32 + + if training_args.do_train: + num_train_samples = len(vectorized_datasets[data_args.train_split_name]) + steps_per_epoch = num_train_samples // batch_size_per_update + if max_steps > 0: + num_epochs = -(training_args.max_steps // -steps_per_epoch) + total_train_steps = max_steps + else: + num_epochs = int(training_args.num_train_epochs) + total_train_steps = steps_per_epoch * num_epochs + + # Create learning rate schedule + # Create learning rate schedule + linear_decay_lr_schedule_fn = create_learning_rate_fn( + total_train_steps, + training_args.warmup_steps, + training_args.learning_rate, + ) + + # We use Optax's "masking" functionality to not apply weight decay + # to bias and LayerNorm scale parameters. decay_mask_fn returns a + # mask boolean with the same structure as the parameters. + # The mask is True for parameters that should be decayed. + # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart. + # For FlaxT5, one should correct the layer norm parameter naming + # accordingly - see `run_t5_mlm_flax.py` e.g. + def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + layer_norm_params = [ + (name, "scale") + for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"] + ] + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + + if training_args.adafactor: + # Create Adafactor optimizer + optim = optax.adafactor( + learning_rate=linear_decay_lr_schedule_fn, + dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32, + weight_decay_rate=training_args.weight_decay, + weight_decay_mask=decay_mask_fn, + ) + else: + # Create AdamW optimizer + optim = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn, + ) + + # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. 
gradient accumulation steps > 1) + if training_args.multisteps and gradient_accumulation_steps > 1: + optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False) + else: + num_epochs = 0 + total_train_steps = 0 + num_train_samples = 0 + optim = None + + # Setup train state + state = MixedPrecisionTrainState.create( + apply_fn=model.__call__, + get_attention_mask_fn=model._get_feature_vector_attention_mask, + params=model.params, + tx=optim, + to_dtype=to_dtype, + dropout_rng=dropout_rng, + max_grad_norm=training_args.max_grad_norm, + ) + + # Replicate the train state on each device + state = state.replicate() + blank_id = model.config.pad_token_id + + # Define gradient update step fn + def train_step(state, batch): + # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch + dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + + def compute_loss(params, minibatch): + labels = minibatch.pop("labels") + logits = state.apply_fn( + **minibatch, + params=params, + dropout_rng=dropout_rng, + freeze_feature_encoder=model_args.freeze_feature_encoder, + train=True, + )[0] + logits_mask = state.get_attention_mask_fn(logits.shape[1], batch["attention_mask"]) + loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean") + + return loss + + grad_fn = jax.value_and_grad(compute_loss) + + if gradient_accumulation_steps == 1 or training_args.multisteps: + loss, grad = grad_fn(to_dtype(state.params), batch) + + # Custom gradient accumulation + else: + # add a first dimension over gradient_accumulation_steps for minibatch slices + batch = jax.tree_util.tree_map( + lambda x: x.reshape( + gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::] + ), + batch, + ) + + def accum_minibatch_step(accum_grad, minibatch): + # compute loss, num labels and grad over minibatch and accumulate + loss, grad = grad_fn(to_dtype(state.params), minibatch) + return jax.tree_util.tree_map(jnp.add, accum_grad, grad), loss + + # create an initial state for accumulating losses, num labels and gradients + init_grad = jax.tree_util.tree_map(jnp.zeros_like, to_dtype(state.params)) + # loop accum minibatch step over the number of gradient accumulation steps + grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch) + + # update state + new_state = state.apply_gradients( + grads=grad, + dropout_rng=new_dropout_rng, + to_dtype=to_dtype, + ) + + # compute gradient norms over all layers and globally for detailed monitoring + layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad) + logs = { + "layer_grad_norm": layer_grad_norm, + "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)), + } + + # compute parameter norms over all layers and globally for detailed monitoring + layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params) + logs["layer_param_norm"] = layer_param_norm + logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm)) + + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + metrics.update(logs) + + metrics = jax.lax.pmean(metrics, axis_name="batch") + # metrics = to_fp32(metrics) + + return new_state, metrics + + # Define eval fn + def eval_step(params, batch): + labels = batch.pop("labels") + logits = model(**batch, params=params, train=False)[0] + + logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], batch["attention_mask"]) + loss = 
ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean") + + pred_ids = jnp.argmax(logits, axis=-1) + + # summarize metrics + metrics = {"loss": loss} + metrics = jax.lax.pmean(metrics, axis_name="batch") + # metrics = to_fp32(metrics) + return metrics, pred_ids + + # Create parallel version of the train and eval step + if training_args.do_train: + p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) + + if training_args.do_eval: + p_eval_step = jax.pmap(eval_step, "batch") + + def run_evaluation(step): + if training_args.do_eval: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_preds = [] + eval_labels = [] + + # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length + eval_samples_idx = get_grouped_indices(vectorized_datasets[data_args.eval_split_name], eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) + + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): + samples = [vectorized_datasets[data_args.eval_split_name][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + labels = batch["labels"] + + metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size) + eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))) + eval_metrics.append(metrics) + + eval_labels.extend(labels) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + eval_metrics = to_fp32(eval_metrics) + + # always run compute metrics + error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels) + eval_metrics.update(error_rate_metric) + error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()]) + + # Print metrics and update progress bar + desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})" + epochs.write(desc) + epochs.desc = desc + + # Save metrics + write_wandb_log(eval_metrics, step, prefix="eval") + write_wandb_pred(pred_str, label_str, step) + # if has_tensorboard and jax.process_index() == 0: + # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str) + + def save_checkpoint(step): + # save and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False) + + skip_epochs = data_args.skip_steps // (num_train_samples // batch_size_per_update) + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_train_samples}") + logger.info(f" Num Epochs = {num_epochs}") + logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") + logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}") + logger.info(f" Total train batch size (w. 
parallel & distributed) = {batch_size_per_update}")
+    logger.info(f" Total optimization steps = {total_train_steps}")
+    logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}")
+    logger.info(f" Use scan: {config.use_scan}")
+    logger.info(f" Fuse matmuls: {config.fuse_matmuls}")
+    logger.info(f" Skipping: {data_args.skip_steps} steps ({skip_epochs} epochs)")
+
+    train_time = cur_step = 0
+    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+    for epoch in epochs:
+        if training_args.do_train:
+            # ======================== Training ================================
+            train_start = time.time()
+
+            if epoch < skip_epochs:
+                logger.info(f"Skipping epoch {epoch + 1}")
+                continue
+
+            # Create sampling rng
+            rng, input_rng = jax.random.split(rng)
+
+            # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
+            train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng)
+            train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
+
+            if data_args.skip_steps > cur_step:
+                logger.info(f"Skipping {data_args.skip_steps - (epoch * (num_train_samples // batch_size_per_update))} steps...")
+            # Gather the indices for creating the batch and do a training step
+            for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
+                cur_step = epoch * (num_train_samples // batch_size_per_update) + step
+                if cur_step <= data_args.skip_steps:
+                    continue
+
+                samples = [vectorized_datasets[data_args.train_split_name][int(idx)] for idx in batch_idx]
+                batch = data_collator(samples)
+                batch = shard(batch.data)
+                try:
+                    state, train_metric = p_train_step(state, batch)
+                except TypeError as e:
+                    # interpolate the exception into the message so it is actually logged
+                    logger.warning(f"Encountered the following error: \n{e}")
+
+                if cur_step % training_args.logging_steps == 0:
+                    # Save metrics
+                    train_metric = unreplicate(train_metric)
+                    train_time += time.time() - train_start
+                    # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
+                    write_wandb_log(to_fp32(train_metric), cur_step, prefix=data_args.train_split_name)
+                    # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
+                    # if has_tensorboard and jax.process_index() == 0:
+                    #     write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+                    epochs.write(
+                        f"Step...
({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})" + ) + + if cur_step % total_train_steps == 0: + break + + if training_args.eval_steps and cur_step % training_args.eval_steps == 0: + run_evaluation(cur_step) + + if cur_step % training_args.save_steps == 0: + save_checkpoint(cur_step) + + if training_args.eval_steps == 0 and (epoch + 1) != num_epochs: + # run evaluation at the end of the epoch if eval steps are not specified + run_evaluation(cur_step) + save_checkpoint(cur_step) + + if training_args.do_train: + save_checkpoint(cur_step) + + cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training + + if training_args.do_eval: + run_evaluation(cur_step) + + # TODO: collapse 'do_predict' into the run_evaluation function + if training_args.do_predict: + for split in [data_args.test_split_name]: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_preds = [] + eval_labels = [] + + # Generate eval set by sequentially sampling indices from the test dataset and grouping by length + eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) + + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)): + samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + labels = batch["labels"] + + metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size) + eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))) + eval_metrics.append(metrics) + + eval_labels.extend(labels) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + eval_metrics = to_fp32(eval_metrics) + + # always run compute metrics + error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels) + eval_metrics.update(error_rate_metric) + error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()]) + + # Print metrics and update progress bar + desc = f"Step... 
({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})" + epochs.write(desc) + epochs.desc = desc + + # Save metrics + write_wandb_log(eval_metrics, cur_step, prefix=split) + write_wandb_pred(pred_str, label_str, cur_step, prefix=split) + # if has_tensorboard and jax.process_index() == 0: + # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str) + + +if __name__ == "__main__": + main() diff --git a/wandb/run-20220810_073735-23avj35z/files/config.yaml b/wandb/run-20220810_073735-23avj35z/files/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dbd4f52063597c3d5171962b28f4081c0f409fea --- /dev/null +++ b/wandb/run-20220810_073735-23avj35z/files/config.yaml @@ -0,0 +1,33 @@ +wandb_version: 1 + +_wandb: + desc: null + value: + cli_version: 0.12.9 + code_path: code/run_flax_speech_recognition_ctc.py + framework: huggingface + huggingface_version: 4.21.0 + is_jupyter_run: false + is_kaggle_kernel: false + python_version: 3.8.10 + start_time: 1660117055 + t: + 1: + - 1 + - 2 + - 3 + - 11 + - 12 + 2: + - 1 + - 2 + - 3 + - 11 + - 12 + 3: + - 13 + 4: 3.8.10 + 5: 0.12.9 + 6: 4.21.0 + 8: + - 5 diff --git a/wandb/run-20220810_073735-23avj35z/files/diff.patch b/wandb/run-20220810_073735-23avj35z/files/diff.patch new file mode 100644 index 0000000000000000000000000000000000000000..76d634027256b67a8e0df0887db6f4601551fdc5 --- /dev/null +++ b/wandb/run-20220810_073735-23avj35z/files/diff.patch @@ -0,0 +1,27 @@ +diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log +index 23926ef..9213b33 120000 +--- a/wandb/debug-internal.log ++++ b/wandb/debug-internal.log +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4/logs/debug-internal.log +\ No newline at end of file ++run-20220810_073735-23avj35z/logs/debug-internal.log +\ No newline at end of file +diff --git a/wandb/debug.log b/wandb/debug.log +index 279853d..bcac724 120000 +--- a/wandb/debug.log ++++ b/wandb/debug.log +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4/logs/debug.log +\ No newline at end of file ++run-20220810_073735-23avj35z/logs/debug.log +\ No newline at end of file +diff --git a/wandb/latest-run b/wandb/latest-run +index f069a7a..1406fac 120000 +--- a/wandb/latest-run ++++ b/wandb/latest-run +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4 +\ No newline at end of file ++run-20220810_073735-23avj35z +\ No newline at end of file diff --git a/wandb/run-20220810_073735-23avj35z/files/output.log b/wandb/run-20220810_073735-23avj35z/files/output.log new file mode 100644 index 0000000000000000000000000000000000000000..60d8fa59823d57060ba46af7d392b7a822911a43 --- /dev/null +++ b/wandb/run-20220810_073735-23avj35z/files/output.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b15353b15528bc042b1df6aa006abb62291e8c20dc7fa0bfe25bddcdf5307ef +size 166570 diff --git a/wandb/run-20220810_073735-23avj35z/files/requirements.txt b/wandb/run-20220810_073735-23avj35z/files/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..5ef78cbdea431c3d66bc5af51443394cb7955eab --- /dev/null +++ b/wandb/run-20220810_073735-23avj35z/files/requirements.txt @@ -0,0 +1,158 @@ +absl-py==1.0.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +appdirs==1.4.4 +astunparse==1.6.3 +async-timeout==4.0.2 +attrs==21.4.0 +audioread==2.1.9 +backcall==0.2.0 +cachetools==4.2.4 +certifi==2021.10.8 +cffi==1.15.1 +charset-normalizer==2.0.10 +chex==0.1.3 +click==8.0.3 +cloud-tpu-client==0.10 +cloud-tpu-profiler==2.4.0 +clu==0.0.6 +colorama==0.4.5 +commonmark==0.9.1 
+configparser==5.2.0 +contextlib2==21.6.0 +cycler==0.11.0 +datasets==2.4.0 +decorator==5.1.0 +dill==0.3.4 +dm-tree==0.1.6 +docker-pycreds==0.4.0 +etils==0.6.0 +exceptiongroup==1.0.0rc8 +filelock==3.4.2 +flatbuffers==2.0 +flax==0.5.3 +fonttools==4.28.5 +frozenlist==1.2.0 +fsspec==2021.11.1 +future==0.18.2 +gast==0.4.0 +gitdb==4.0.9 +gitpython==3.1.26 +google-api-core==1.31.5 +google-api-python-client==1.8.0 +google-auth-httplib2==0.1.0 +google-auth-oauthlib==0.4.6 +google-auth==2.3.3 +google-pasta==0.2.0 +googleapis-common-protos==1.54.0 +grpcio==1.43.0 +h5py==3.6.0 +httplib2==0.20.2 +huggingface-hub==0.2.1 +hypothesis==6.53.0 +idna==3.3 +importlib-metadata==4.10.0 +importlib-resources==5.4.0 +ipython==7.31.0 +jax==0.3.15 +jaxlib==0.3.15 +jedi==0.18.1 +jiwer==2.3.0 +joblib==1.1.0 +keras-preprocessing==1.1.2 +keras==2.7.0 +kiwisolver==1.3.2 +libclang==12.0.0 +librosa==0.9.2 +libtpu-nightly==0.1.dev20220722 +llvmlite==0.39.0 +markdown==3.3.6 +matplotlib-inline==0.1.3 +matplotlib==3.5.1 +ml-collections==0.1.0 +msgpack==1.0.3 +multidict==5.2.0 +multiprocess==0.70.12.2 +numba==0.56.0 +numpy==1.22.0 +oauth2client==4.1.3 +oauthlib==3.1.1 +opt-einsum==3.3.0 +optax==0.1.3 +packaging==21.3 +pandas==1.3.5 +parso==0.8.3 +pathtools==0.1.2 +pexpect==4.8.0 +pickleshare==0.7.5 +pillow==9.0.0 +pip==22.2.1 +pkg-resources==0.0.0 +pooch==1.6.0 +promise==2.3 +prompt-toolkit==3.0.24 +protobuf==3.19.1 +psutil==5.9.0 +ptyprocess==0.7.0 +pyarrow==6.0.1 +pyasn1-modules==0.2.8 +pyasn1==0.4.8 +pycparser==2.21 +pyctcdecode==0.4.0 +pygments==2.11.1 +pygtrie==2.5.0 +pyparsing==3.0.6 +python-dateutil==2.8.2 +python-levenshtein==0.12.2 +pytz==2021.3 +pyyaml==6.0 +regex==2021.11.10 +requests-oauthlib==1.3.0 +requests==2.27.0 +resampy==0.3.1 +responses==0.18.0 +rich==11.2.0 +rsa==4.8 +sacremoses==0.0.46 +scikit-learn==1.1.1 +scipy==1.7.3 +sentry-sdk==1.5.2 +setuptools==44.0.0 +shortuuid==1.0.8 +six==1.16.0 +smmap==5.0.0 +sortedcontainers==2.4.0 +soundfile==0.10.3.post1 +sox==1.4.1 +subprocess32==3.5.4 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.0 +tensorboard==2.7.0 +tensorflow-cpu==2.7.0 +tensorflow-datasets==4.4.0 +tensorflow-estimator==2.7.0 +tensorflow-io-gcs-filesystem==0.23.1 +tensorflow-metadata==1.5.0 +tensorflow==2.7.0 +tensorstore==0.1.21 +termcolor==1.1.0 +threadpoolctl==3.1.0 +tokenizers==0.11.2 +toolz==0.11.2 +torch==1.12.0 +torchaudio==0.12.0+cpu +tqdm==4.62.3 +traitlets==5.1.1 +transformers==4.21.0 +typing-extensions==4.3.0 +uritemplate==3.0.1 +urllib3==1.26.7 +wandb==0.12.9 +wcwidth==0.2.5 +werkzeug==2.0.2 +wheel==0.37.1 +wrapt==1.13.3 +xxhash==2.0.2 +yarl==1.7.2 +yaspin==2.1.0 +zipp==3.7.0 \ No newline at end of file diff --git a/wandb/run-20220810_073735-23avj35z/files/wandb-metadata.json b/wandb/run-20220810_073735-23avj35z/files/wandb-metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..a8f24e8d951899e682351b78def6c558b36b1c77 --- /dev/null +++ b/wandb/run-20220810_073735-23avj35z/files/wandb-metadata.json @@ -0,0 +1,70 @@ +{ + "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29", + "python": "3.8.10", + "heartbeatAt": "2022-08-10T07:37:39.012020", + "startedAt": "2022-08-10T07:37:35.560272", + "docker": null, + "cpu_count": 96, + "cuda": null, + "args": [ + "--model_name_or_path=./", + "--hub_model_id=NbAiLab/wav2vec2-1b-npsc-nst-tpu", + "--tokenizer_name=./", + "--output_dir=./", + "--overwrite_output_dir", + "--num_train_epochs=40", + "--per_device_train_batch_size=2", + "--per_device_eval_batch_size=2", + "--gradient_accumulation_steps=1", + 
"--precision=full_mixed", + "--matmul_precision=bfloat16", + "--multisteps", + "--learning_rate=6.394633237505332e-05", + "--skip_steps=275000", + "--warmup_steps=2000", + "--length_column_name=input_length", + "--evaluation_strategy=steps", + "--text_column_name=text", + "--save_steps=5000", + "--eval_steps=5000", + "--logging_steps=100", + "--layerdrop=0.041", + "--attention_dropout=0.094", + "--activation_dropout=0.055", + "--hidden_dropout=0.047", + "--save_total_limit=5", + "--freeze_feature_encoder", + "--feat_proj_dropout=0.04", + "--mask_time_prob=0.082", + "--mask_time_length=10", + "--mask_feature_prob=0.25", + "--mask_feature_length=64", + "--gradient_checkpointing", + "--min_duration_in_seconds=0.5", + "--max_duration_in_seconds=30.0", + "--use_auth_token", + "--seed=42", + "--group_by_length", + "--do_train", + "--do_eval", + "--push_to_hub", + "--preprocessing_num_workers=32", + "--ctc_zero_infinity", + "--do_lower_case", + "--wandb_project=wav2vec2", + "--wandb_name=wav2vec2-1b-npsc-nst-tpu (cont.)", + "--remove_punctuation" + ], + "state": "running", + "program": "run_flax_speech_recognition_ctc.py", + "codePath": "run_flax_speech_recognition_ctc.py", + "git": { + "remote": "https://huggingface.co/NbAiLab/wav2vec2-1b-npsc-nst-tpu", + "commit": "f624ac4bfedfbf56891676d7c5f2e37b4c8e0745" + }, + "email": "versae@gmail.com", + "root": "/data/wav2vec2-1b-npsc-nst-tpu", + "host": "t1v-n-eedfb410-w-0", + "username": "javierr", + "executable": "/data/flax/bin/python" +} diff --git a/wandb/run-20220810_073735-23avj35z/files/wandb-summary.json b/wandb/run-20220810_073735-23avj35z/files/wandb-summary.json new file mode 100644 index 0000000000000000000000000000000000000000..d008597910ce720253c8c67f282402f80829956a --- /dev/null +++ b/wandb/run-20220810_073735-23avj35z/files/wandb-summary.json @@ -0,0 +1 @@ +{"train/grad_norm": 6.5625, "layer_grad_norm/": {"lm_head": {"bias": 0.031982421875, "kernel": 4.625}, "wav2vec2": {"encoder": {"layer_norm": {"bias": 0.0556640625, "scale": 0.06103515625}, "layers": {"0": {"attention": {"k_proj": {"bias": 5.9604644775390625e-05, "kernel": 0.023681640625}, "out_proj": {"bias": 0.04150390625, "kernel": 0.2431640625}, "q_proj": {"bias": 0.002899169921875, "kernel": 0.031005859375}, "v_proj": {"bias": 0.037109375, "kernel": 0.265625}}, "feed_forward": {"intermediate_dense": {"bias": 0.04443359375, "kernel": 0.515625}, "output_dense": {"bias": 0.0159912109375, "kernel": 0.439453125}}, "final_layer_norm": {"bias": 0.146484375, "scale": 0.322265625}, "layer_norm": {"bias": 0.0703125, "scale": 0.07080078125}}, "1": {"attention": {"k_proj": {"bias": 3.4332275390625e-05, "kernel": 0.03955078125}, "out_proj": {"bias": 0.0177001953125, "kernel": 0.134765625}, "q_proj": {"bias": 0.0035247802734375, "kernel": 0.0439453125}, "v_proj": {"bias": 0.02880859375, "kernel": 0.111328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.022216796875, "kernel": 0.3359375}, "output_dense": {"bias": 0.0157470703125, "kernel": 0.259765625}}, "final_layer_norm": {"bias": 0.046142578125, "scale": 0.05712890625}, "layer_norm": {"bias": 0.05712890625, "scale": 0.039794921875}}, "10": {"attention": {"k_proj": {"bias": 3.600120544433594e-05, "kernel": 0.123046875}, "out_proj": {"bias": 0.01416015625, "kernel": 0.2001953125}, "q_proj": {"bias": 0.0078125, "kernel": 0.12255859375}, "v_proj": {"bias": 0.022705078125, "kernel": 0.2001953125}}, "feed_forward": {"intermediate_dense": {"bias": 0.021484375, "kernel": 0.328125}, "output_dense": {"bias": 0.013671875, "kernel": 
0.2734375}}, "final_layer_norm": {"bias": 0.0390625, "scale": 0.0322265625}, "layer_norm": {"bias": 0.04931640625, "scale": 0.0341796875}}, "11": {"attention": {"k_proj": {"bias": 8.344650268554688e-05, "kernel": 0.158203125}, "out_proj": {"bias": 0.0142822265625, "kernel": 0.28125}, "q_proj": {"bias": 0.0087890625, "kernel": 0.130859375}, "v_proj": {"bias": 0.024658203125, "kernel": 0.28515625}}, "feed_forward": {"intermediate_dense": {"bias": 0.01953125, "kernel": 0.310546875}, "output_dense": {"bias": 0.013916015625, "kernel": 0.244140625}}, "final_layer_norm": {"bias": 0.03271484375, "scale": 0.0308837890625}, "layer_norm": {"bias": 0.05029296875, "scale": 0.0439453125}}, "12": {"attention": {"k_proj": {"bias": 4.982948303222656e-05, "kernel": 0.123046875}, "out_proj": {"bias": 0.0147705078125, "kernel": 0.244140625}, "q_proj": {"bias": 0.0081787109375, "kernel": 0.1162109375}, "v_proj": {"bias": 0.023681640625, "kernel": 0.2294921875}}, "feed_forward": {"intermediate_dense": {"bias": 0.021484375, "kernel": 0.32421875}, "output_dense": {"bias": 0.014404296875, "kernel": 0.255859375}}, "final_layer_norm": {"bias": 0.0390625, "scale": 0.04248046875}, "layer_norm": {"bias": 0.046630859375, "scale": 0.0546875}}, "13": {"attention": {"k_proj": {"bias": 0.00012493133544921875, "kernel": 0.15625}, "out_proj": {"bias": 0.01519775390625, "kernel": 0.330078125}, "q_proj": {"bias": 0.0111083984375, "kernel": 0.158203125}, "v_proj": {"bias": 0.026611328125, "kernel": 0.3046875}}, "feed_forward": {"intermediate_dense": {"bias": 0.022216796875, "kernel": 0.34375}, "output_dense": {"bias": 0.01513671875, "kernel": 0.3125}}, "final_layer_norm": {"bias": 0.040283203125, "scale": 0.032958984375}, "layer_norm": {"bias": 0.051513671875, "scale": 0.091796875}}, "14": {"attention": {"k_proj": {"bias": 9.441375732421875e-05, "kernel": 0.1005859375}, "out_proj": {"bias": 0.015625, "kernel": 0.2412109375}, "q_proj": {"bias": 0.006256103515625, "kernel": 0.099609375}, "v_proj": {"bias": 0.0235595703125, "kernel": 0.2275390625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0257568359375, "kernel": 0.39453125}, "output_dense": {"bias": 0.015380859375, "kernel": 0.33984375}}, "final_layer_norm": {"bias": 0.05126953125, "scale": 0.05517578125}, "layer_norm": {"bias": 0.041748046875, "scale": 0.03076171875}}, "15": {"attention": {"k_proj": {"bias": 0.0003070831298828125, "kernel": 0.1806640625}, "out_proj": {"bias": 0.015625, "kernel": 0.5078125}, "q_proj": {"bias": 0.0106201171875, "kernel": 0.173828125}, "v_proj": {"bias": 0.026611328125, "kernel": 0.361328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.024169921875, "kernel": 0.376953125}, "output_dense": {"bias": 0.01556396484375, "kernel": 0.349609375}}, "final_layer_norm": {"bias": 0.041015625, "scale": 0.033447265625}, "layer_norm": {"bias": 0.048095703125, "scale": 0.072265625}}, "16": {"attention": {"k_proj": {"bias": 6.389617919921875e-05, "kernel": 0.1025390625}, "out_proj": {"bias": 0.016357421875, "kernel": 0.267578125}, "q_proj": {"bias": 0.0057373046875, "kernel": 0.1005859375}, "v_proj": {"bias": 0.0238037109375, "kernel": 0.220703125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0223388671875, "kernel": 0.359375}, "output_dense": {"bias": 0.0159912109375, "kernel": 0.341796875}}, "final_layer_norm": {"bias": 0.0380859375, "scale": 0.033935546875}, "layer_norm": {"bias": 0.043212890625, "scale": 0.034912109375}}, "17": {"attention": {"k_proj": {"bias": 4.57763671875e-05, "kernel": 0.0927734375}, "out_proj": {"bias": 
0.0172119140625, "kernel": 0.23046875}, "q_proj": {"bias": 0.005889892578125, "kernel": 0.087890625}, "v_proj": {"bias": 0.0244140625, "kernel": 0.2177734375}}, "feed_forward": {"intermediate_dense": {"bias": 0.024169921875, "kernel": 0.390625}, "output_dense": {"bias": 0.01708984375, "kernel": 0.353515625}}, "final_layer_norm": {"bias": 0.041259765625, "scale": 0.036376953125}, "layer_norm": {"bias": 0.0439453125, "scale": 0.0341796875}}, "18": {"attention": {"k_proj": {"bias": 0.000247955322265625, "kernel": 0.126953125}, "out_proj": {"bias": 0.017578125, "kernel": 0.369140625}, "q_proj": {"bias": 0.0076904296875, "kernel": 0.1337890625}, "v_proj": {"bias": 0.027587890625, "kernel": 0.298828125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02587890625, "kernel": 0.44921875}, "output_dense": {"bias": 0.0174560546875, "kernel": 0.41015625}}, "final_layer_norm": {"bias": 0.04443359375, "scale": 0.03857421875}, "layer_norm": {"bias": 0.048583984375, "scale": 0.039794921875}}, "19": {"attention": {"k_proj": {"bias": 8.678436279296875e-05, "kernel": 0.140625}, "out_proj": {"bias": 0.017822265625, "kernel": 0.28125}, "q_proj": {"bias": 0.009033203125, "kernel": 0.140625}, "v_proj": {"bias": 0.0286865234375, "kernel": 0.283203125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02587890625, "kernel": 0.474609375}, "output_dense": {"bias": 0.0174560546875, "kernel": 0.421875}}, "final_layer_norm": {"bias": 0.041748046875, "scale": 0.0380859375}, "layer_norm": {"bias": 0.052734375, "scale": 0.04052734375}}, "2": {"attention": {"k_proj": {"bias": 4.982948303222656e-05, "kernel": 0.07421875}, "out_proj": {"bias": 0.0177001953125, "kernel": 0.2060546875}, "q_proj": {"bias": 0.006195068359375, "kernel": 0.06982421875}, "v_proj": {"bias": 0.03173828125, "kernel": 0.181640625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0234375, "kernel": 0.390625}, "output_dense": {"bias": 0.01556396484375, "kernel": 0.275390625}}, "final_layer_norm": {"bias": 0.047119140625, "scale": 0.03173828125}, "layer_norm": {"bias": 0.0556640625, "scale": 0.07275390625}}, "20": {"attention": {"k_proj": {"bias": 2.110004425048828e-05, "kernel": 0.095703125}, "out_proj": {"bias": 0.0185546875, "kernel": 0.142578125}, "q_proj": {"bias": 0.005157470703125, "kernel": 0.0947265625}, "v_proj": {"bias": 0.0263671875, "kernel": 0.140625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0250244140625, "kernel": 0.4765625}, "output_dense": {"bias": 0.018310546875, "kernel": 0.390625}}, "final_layer_norm": {"bias": 0.041015625, "scale": 0.04443359375}, "layer_norm": {"bias": 0.046142578125, "scale": 0.038330078125}}, "21": {"attention": {"k_proj": {"bias": 4.00543212890625e-05, "kernel": 0.1259765625}, "out_proj": {"bias": 0.0189208984375, "kernel": 0.2216796875}, "q_proj": {"bias": 0.006927490234375, "kernel": 0.12890625}, "v_proj": {"bias": 0.02734375, "kernel": 0.203125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0267333984375, "kernel": 0.51953125}, "output_dense": {"bias": 0.0185546875, "kernel": 0.41796875}}, "final_layer_norm": {"bias": 0.04541015625, "scale": 0.04736328125}, "layer_norm": {"bias": 0.044189453125, "scale": 0.054443359375}}, "22": {"attention": {"k_proj": {"bias": 3.3855438232421875e-05, "kernel": 0.1181640625}, "out_proj": {"bias": 0.019775390625, "kernel": 0.240234375}, "q_proj": {"bias": 0.006011962890625, "kernel": 0.11279296875}, "v_proj": {"bias": 0.028076171875, "kernel": 0.21875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0269775390625, "kernel": 0.515625}, 
"output_dense": {"bias": 0.0194091796875, "kernel": 0.400390625}}, "final_layer_norm": {"bias": 0.046142578125, "scale": 0.047119140625}, "layer_norm": {"bias": 0.049560546875, "scale": 0.0458984375}}, "23": {"attention": {"k_proj": {"bias": 0.0001087188720703125, "kernel": 0.16015625}, "out_proj": {"bias": 0.0198974609375, "kernel": 0.443359375}, "q_proj": {"bias": 0.008544921875, "kernel": 0.1630859375}, "v_proj": {"bias": 0.03173828125, "kernel": 0.35546875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0263671875, "kernel": 0.53125}, "output_dense": {"bias": 0.01953125, "kernel": 0.400390625}}, "final_layer_norm": {"bias": 0.044677734375, "scale": 0.04638671875}, "layer_norm": {"bias": 0.05615234375, "scale": 0.056396484375}}, "24": {"attention": {"k_proj": {"bias": 6.246566772460938e-05, "kernel": 0.169921875}, "out_proj": {"bias": 0.0191650390625, "kernel": 0.36328125}, "q_proj": {"bias": 0.00933837890625, "kernel": 0.18359375}, "v_proj": {"bias": 0.03271484375, "kernel": 0.328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02685546875, "kernel": 0.5390625}, "output_dense": {"bias": 0.01904296875, "kernel": 0.37890625}}, "final_layer_norm": {"bias": 0.04736328125, "scale": 0.04345703125}, "layer_norm": {"bias": 0.0625, "scale": 0.041015625}}, "25": {"attention": {"k_proj": {"bias": 6.079673767089844e-05, "kernel": 0.169921875}, "out_proj": {"bias": 0.0196533203125, "kernel": 0.3125}, "q_proj": {"bias": 0.00860595703125, "kernel": 0.16015625}, "v_proj": {"bias": 0.03271484375, "kernel": 0.32421875}}, "feed_forward": {"intermediate_dense": {"bias": 0.02734375, "kernel": 0.55859375}, "output_dense": {"bias": 0.01953125, "kernel": 0.375}}, "final_layer_norm": {"bias": 0.050537109375, "scale": 0.0478515625}, "layer_norm": {"bias": 0.06005859375, "scale": 0.06298828125}}, "26": {"attention": {"k_proj": {"bias": 5.9604644775390625e-05, "kernel": 0.169921875}, "out_proj": {"bias": 0.01953125, "kernel": 0.29296875}, "q_proj": {"bias": 0.01025390625, "kernel": 0.177734375}, "v_proj": {"bias": 0.0341796875, "kernel": 0.296875}}, "feed_forward": {"intermediate_dense": {"bias": 0.026611328125, "kernel": 0.51171875}, "output_dense": {"bias": 0.01904296875, "kernel": 0.353515625}}, "final_layer_norm": {"bias": 0.0478515625, "scale": 0.04443359375}, "layer_norm": {"bias": 0.060791015625, "scale": 0.06396484375}}, "27": {"attention": {"k_proj": {"bias": 0.00011396408081054688, "kernel": 0.2021484375}, "out_proj": {"bias": 0.01806640625, "kernel": 0.44921875}, "q_proj": {"bias": 0.01068115234375, "kernel": 0.2138671875}, "v_proj": {"bias": 0.03466796875, "kernel": 0.435546875}}, "feed_forward": {"intermediate_dense": {"bias": 0.02734375, "kernel": 0.515625}, "output_dense": {"bias": 0.0181884765625, "kernel": 0.36328125}}, "final_layer_norm": {"bias": 0.05078125, "scale": 0.045654296875}, "layer_norm": {"bias": 0.06640625, "scale": 0.04931640625}}, "28": {"attention": {"k_proj": {"bias": 0.0001049041748046875, "kernel": 0.20703125}, "out_proj": {"bias": 0.0164794921875, "kernel": 0.392578125}, "q_proj": {"bias": 0.01165771484375, "kernel": 0.208984375}, "v_proj": {"bias": 0.031494140625, "kernel": 0.404296875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0234375, "kernel": 0.45703125}, "output_dense": {"bias": 0.016357421875, "kernel": 0.326171875}}, "final_layer_norm": {"bias": 0.04248046875, "scale": 0.044921875}, "layer_norm": {"bias": 0.0673828125, "scale": 0.08447265625}}, "29": {"attention": {"k_proj": {"bias": 9.918212890625e-05, "kernel": 0.267578125}, "out_proj": 
{"bias": 0.0157470703125, "kernel": 0.28515625}, "q_proj": {"bias": 0.01495361328125, "kernel": 0.265625}, "v_proj": {"bias": 0.02978515625, "kernel": 0.31640625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0234375, "kernel": 0.494140625}, "output_dense": {"bias": 0.01531982421875, "kernel": 0.296875}}, "final_layer_norm": {"bias": 0.03955078125, "scale": 0.03515625}, "layer_norm": {"bias": 0.0654296875, "scale": 0.061279296875}}, "3": {"attention": {"k_proj": {"bias": 0.00012111663818359375, "kernel": 0.0986328125}, "out_proj": {"bias": 0.016845703125, "kernel": 0.314453125}, "q_proj": {"bias": 0.00726318359375, "kernel": 0.0888671875}, "v_proj": {"bias": 0.0283203125, "kernel": 0.2470703125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0242919921875, "kernel": 0.3828125}, "output_dense": {"bias": 0.0150146484375, "kernel": 0.271484375}}, "final_layer_norm": {"bias": 0.0458984375, "scale": 0.03125}, "layer_norm": {"bias": 0.0498046875, "scale": 0.0380859375}}, "30": {"attention": {"k_proj": {"bias": 0.0001220703125, "kernel": 0.13671875}, "out_proj": {"bias": 0.0152587890625, "kernel": 0.328125}, "q_proj": {"bias": 0.006378173828125, "kernel": 0.138671875}, "v_proj": {"bias": 0.029296875, "kernel": 0.3671875}}, "feed_forward": {"intermediate_dense": {"bias": 0.023681640625, "kernel": 0.51953125}, "output_dense": {"bias": 0.01446533203125, "kernel": 0.26171875}}, "final_layer_norm": {"bias": 0.04150390625, "scale": 0.03564453125}, "layer_norm": {"bias": 0.04931640625, "scale": 0.037109375}}, "31": {"attention": {"k_proj": {"bias": 0.00010347366333007812, "kernel": 0.14453125}, "out_proj": {"bias": 0.0140380859375, "kernel": 0.29296875}, "q_proj": {"bias": 0.006378173828125, "kernel": 0.134765625}, "v_proj": {"bias": 0.0238037109375, "kernel": 0.314453125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02392578125, "kernel": 0.51953125}, "output_dense": {"bias": 0.01385498046875, "kernel": 0.2578125}}, "final_layer_norm": {"bias": 0.0390625, "scale": 0.03662109375}, "layer_norm": {"bias": 0.039306640625, "scale": 0.0291748046875}}, "32": {"attention": {"k_proj": {"bias": 8.296966552734375e-05, "kernel": 0.15625}, "out_proj": {"bias": 0.01263427734375, "kernel": 0.28125}, "q_proj": {"bias": 0.0079345703125, "kernel": 0.1533203125}, "v_proj": {"bias": 0.0264892578125, "kernel": 0.4921875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0216064453125, "kernel": 0.431640625}, "output_dense": {"bias": 0.01129150390625, "kernel": 0.212890625}}, "final_layer_norm": {"bias": 0.04150390625, "scale": 0.03271484375}, "layer_norm": {"bias": 0.046630859375, "scale": 0.05419921875}}, "33": {"attention": {"k_proj": {"bias": 5.9604644775390625e-05, "kernel": 0.166015625}, "out_proj": {"bias": 0.01092529296875, "kernel": 0.2275390625}, "q_proj": {"bias": 0.008544921875, "kernel": 0.166015625}, "v_proj": {"bias": 0.023193359375, "kernel": 0.34765625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0196533203125, "kernel": 0.390625}, "output_dense": {"bias": 0.00897216796875, "kernel": 0.1875}}, "final_layer_norm": {"bias": 0.04345703125, "scale": 0.0361328125}, "layer_norm": {"bias": 0.039794921875, "scale": 0.0498046875}}, "34": {"attention": {"k_proj": {"bias": 0.0002346038818359375, "kernel": 0.158203125}, "out_proj": {"bias": 0.0081787109375, "kernel": 0.181640625}, "q_proj": {"bias": 0.006927490234375, "kernel": 0.14453125}, "v_proj": {"bias": 0.0177001953125, "kernel": 0.25390625}}, "feed_forward": {"intermediate_dense": {"bias": 0.01434326171875, "kernel": 0.291015625}, 
"output_dense": {"bias": 0.0072021484375, "kernel": 0.1748046875}}, "final_layer_norm": {"bias": 0.028076171875, "scale": 0.025146484375}, "layer_norm": {"bias": 0.03369140625, "scale": 0.026611328125}}, "35": {"attention": {"k_proj": {"bias": 0.0001506805419921875, "kernel": 0.10791015625}, "out_proj": {"bias": 0.00640869140625, "kernel": 0.2109375}, "q_proj": {"bias": 0.004852294921875, "kernel": 0.10791015625}, "v_proj": {"bias": 0.01177978515625, "kernel": 0.21484375}}, "feed_forward": {"intermediate_dense": {"bias": 0.010498046875, "kernel": 0.2119140625}, "output_dense": {"bias": 0.005889892578125, "kernel": 0.15234375}}, "final_layer_norm": {"bias": 0.0206298828125, "scale": 0.0220947265625}, "layer_norm": {"bias": 0.024169921875, "scale": 0.02880859375}}, "36": {"attention": {"k_proj": {"bias": 4.410743713378906e-05, "kernel": 0.1005859375}, "out_proj": {"bias": 0.005645751953125, "kernel": 0.1552734375}, "q_proj": {"bias": 0.00445556640625, "kernel": 0.095703125}, "v_proj": {"bias": 0.00946044921875, "kernel": 0.14453125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0089111328125, "kernel": 0.177734375}, "output_dense": {"bias": 0.0050048828125, "kernel": 0.111328125}}, "final_layer_norm": {"bias": 0.017578125, "scale": 0.01513671875}, "layer_norm": {"bias": 0.0191650390625, "scale": 0.01806640625}}, "37": {"attention": {"k_proj": {"bias": 9.441375732421875e-05, "kernel": 0.0849609375}, "out_proj": {"bias": 0.004913330078125, "kernel": 0.11474609375}, "q_proj": {"bias": 0.00390625, "kernel": 0.0830078125}, "v_proj": {"bias": 0.00897216796875, "kernel": 0.1318359375}}, "feed_forward": {"intermediate_dense": {"bias": 0.00823974609375, "kernel": 0.16796875}, "output_dense": {"bias": 0.004241943359375, "kernel": 0.09716796875}}, "final_layer_norm": {"bias": 0.015869140625, "scale": 0.01434326171875}, "layer_norm": {"bias": 0.019287109375, "scale": 0.015869140625}}, "38": {"attention": {"k_proj": {"bias": 5.650520324707031e-05, "kernel": 0.09130859375}, "out_proj": {"bias": 0.0040283203125, "kernel": 0.11865234375}, "q_proj": {"bias": 0.00396728515625, "kernel": 0.08642578125}, "v_proj": {"bias": 0.007354736328125, "kernel": 0.1279296875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0072021484375, "kernel": 0.150390625}, "output_dense": {"bias": 0.0034637451171875, "kernel": 0.09423828125}}, "final_layer_norm": {"bias": 0.0152587890625, "scale": 0.0146484375}, "layer_norm": {"bias": 0.0162353515625, "scale": 0.0135498046875}}, "39": {"attention": {"k_proj": {"bias": 5.316734313964844e-05, "kernel": 0.09619140625}, "out_proj": {"bias": 0.0030975341796875, "kernel": 0.09619140625}, "q_proj": {"bias": 0.00408935546875, "kernel": 0.0908203125}, "v_proj": {"bias": 0.006011962890625, "kernel": 0.10986328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.005401611328125, "kernel": 0.12109375}, "output_dense": {"bias": 0.0025634765625, "kernel": 0.08642578125}}, "final_layer_norm": {"bias": 0.01202392578125, "scale": 0.01226806640625}, "layer_norm": {"bias": 0.0150146484375, "scale": 0.01556396484375}}, "4": {"attention": {"k_proj": {"bias": 0.000148773193359375, "kernel": 0.10498046875}, "out_proj": {"bias": 0.015869140625, "kernel": 0.361328125}, "q_proj": {"bias": 0.0072021484375, "kernel": 0.1005859375}, "v_proj": {"bias": 0.026123046875, "kernel": 0.3046875}}, "feed_forward": {"intermediate_dense": {"bias": 0.023193359375, "kernel": 0.36328125}, "output_dense": {"bias": 0.014404296875, "kernel": 0.29296875}}, "final_layer_norm": {"bias": 0.042724609375, "scale": 
0.034423828125}, "layer_norm": {"bias": 0.0478515625, "scale": 0.060546875}}, "40": {"attention": {"k_proj": {"bias": 5.269050598144531e-05, "kernel": 0.046875}, "out_proj": {"bias": 0.0025787353515625, "kernel": 0.080078125}, "q_proj": {"bias": 0.0020294189453125, "kernel": 0.0458984375}, "v_proj": {"bias": 0.004150390625, "kernel": 0.07080078125}}, "feed_forward": {"intermediate_dense": {"bias": 0.004302978515625, "kernel": 0.09326171875}, "output_dense": {"bias": 0.0023040771484375, "kernel": 0.060791015625}}, "final_layer_norm": {"bias": 0.0087890625, "scale": 0.011474609375}, "layer_norm": {"bias": 0.00823974609375, "scale": 0.007781982421875}}, "41": {"attention": {"k_proj": {"bias": 4.0531158447265625e-05, "kernel": 0.0673828125}, "out_proj": {"bias": 0.002044677734375, "kernel": 0.087890625}, "q_proj": {"bias": 0.0025787353515625, "kernel": 0.0634765625}, "v_proj": {"bias": 0.00439453125, "kernel": 0.1044921875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0035858154296875, "kernel": 0.091796875}, "output_dense": {"bias": 0.00168609619140625, "kernel": 0.064453125}}, "final_layer_norm": {"bias": 0.00921630859375, "scale": 0.0106201171875}, "layer_norm": {"bias": 0.010986328125, "scale": 0.0128173828125}}, "42": {"attention": {"k_proj": {"bias": 1.1801719665527344e-05, "kernel": 0.02099609375}, "out_proj": {"bias": 0.001678466796875, "kernel": 0.048828125}, "q_proj": {"bias": 0.0009307861328125, "kernel": 0.02197265625}, "v_proj": {"bias": 0.002349853515625, "kernel": 0.0478515625}}, "feed_forward": {"intermediate_dense": {"bias": 0.002685546875, "kernel": 0.07421875}, "output_dense": {"bias": 0.0014801025390625, "kernel": 0.053955078125}}, "final_layer_norm": {"bias": 0.0059814453125, "scale": 0.00885009765625}, "layer_norm": {"bias": 0.0045166015625, "scale": 0.00439453125}}, "43": {"attention": {"k_proj": {"bias": 8.046627044677734e-06, "kernel": 0.01806640625}, "out_proj": {"bias": 0.0015106201171875, "kernel": 0.035888671875}, "q_proj": {"bias": 0.00095367431640625, "kernel": 0.0208740234375}, "v_proj": {"bias": 0.00171661376953125, "kernel": 0.031005859375}}, "feed_forward": {"intermediate_dense": {"bias": 0.002777099609375, "kernel": 0.083984375}, "output_dense": {"bias": 0.00125885009765625, "kernel": 0.052734375}}, "final_layer_norm": {"bias": 0.00689697265625, "scale": 0.007110595703125}, "layer_norm": {"bias": 0.0036468505859375, "scale": 0.005706787109375}}, "44": {"attention": {"k_proj": {"bias": 1.3113021850585938e-05, "kernel": 0.023681640625}, "out_proj": {"bias": 0.001312255859375, "kernel": 0.033447265625}, "q_proj": {"bias": 0.000946044921875, "kernel": 0.021484375}, "v_proj": {"bias": 0.0017547607421875, "kernel": 0.03515625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0023956298828125, "kernel": 0.0771484375}, "output_dense": {"bias": 0.001129150390625, "kernel": 0.052001953125}}, "final_layer_norm": {"bias": 0.005706787109375, "scale": 0.005859375}, "layer_norm": {"bias": 0.0042724609375, "scale": 0.004730224609375}}, "45": {"attention": {"k_proj": {"bias": 1.4424324035644531e-05, "kernel": 0.01239013671875}, "out_proj": {"bias": 0.00110626220703125, "kernel": 0.0267333984375}, "q_proj": {"bias": 0.00136566162109375, "kernel": 0.029541015625}, "v_proj": {"bias": 0.00142669677734375, "kernel": 0.027587890625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0019989013671875, "kernel": 0.06396484375}, "output_dense": {"bias": 0.000881195068359375, "kernel": 0.05615234375}}, "final_layer_norm": {"bias": 0.006072998046875, "scale": 
0.006591796875}, "layer_norm": {"bias": 0.004791259765625, "scale": 0.00494384765625}}, "46": {"attention": {"k_proj": {"bias": 5.745887756347656e-05, "kernel": 0.006439208984375}, "out_proj": {"bias": 0.000888824462890625, "kernel": 0.0289306640625}, "q_proj": {"bias": 0.000591278076171875, "kernel": 0.011962890625}, "v_proj": {"bias": 0.0010986328125, "kernel": 0.0233154296875}}, "feed_forward": {"intermediate_dense": {"bias": 0.00141143798828125, "kernel": 0.03955078125}, "output_dense": {"bias": 0.000873565673828125, "kernel": 0.0478515625}}, "final_layer_norm": {"bias": 0.00433349609375, "scale": 0.00433349609375}, "layer_norm": {"bias": 0.0036468505859375, "scale": 0.003814697265625}}, "47": {"attention": {"k_proj": {"bias": 0.00011301040649414062, "kernel": 0.003997802734375}, "out_proj": {"bias": 0.000896453857421875, "kernel": 0.06640625}, "q_proj": {"bias": 0.00014591217041015625, "kernel": 0.00286865234375}, "v_proj": {"bias": 0.00118255615234375, "kernel": 0.0230712890625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0010528564453125, "kernel": 0.0252685546875}, "output_dense": {"bias": 0.000881195068359375, "kernel": 0.1787109375}}, "final_layer_norm": {"bias": 0.005950927734375, "scale": 0.00677490234375}, "layer_norm": {"bias": 0.005859375, "scale": 0.005767822265625}}, "5": {"attention": {"k_proj": {"bias": 6.4849853515625e-05, "kernel": 0.123046875}, "out_proj": {"bias": 0.0159912109375, "kernel": 0.203125}, "q_proj": {"bias": 0.007598876953125, "kernel": 0.12158203125}, "v_proj": {"bias": 0.02685546875, "kernel": 0.1953125}}, "feed_forward": {"intermediate_dense": {"bias": 0.023193359375, "kernel": 0.33984375}, "output_dense": {"bias": 0.0147705078125, "kernel": 0.275390625}}, "final_layer_norm": {"bias": 0.041015625, "scale": 0.037841796875}, "layer_norm": {"bias": 0.052734375, "scale": 0.04736328125}}, "6": {"attention": {"k_proj": {"bias": 7.152557373046875e-05, "kernel": 0.1318359375}, "out_proj": {"bias": 0.0152587890625, "kernel": 0.349609375}, "q_proj": {"bias": 0.00823974609375, "kernel": 0.119140625}, "v_proj": {"bias": 0.02685546875, "kernel": 0.31640625}}, "feed_forward": {"intermediate_dense": {"bias": 0.022216796875, "kernel": 0.341796875}, "output_dense": {"bias": 0.014404296875, "kernel": 0.271484375}}, "final_layer_norm": {"bias": 0.0380859375, "scale": 0.0322265625}, "layer_norm": {"bias": 0.049560546875, "scale": 0.058349609375}}, "7": {"attention": {"k_proj": {"bias": 7.62939453125e-05, "kernel": 0.1328125}, "out_proj": {"bias": 0.0150146484375, "kernel": 0.349609375}, "q_proj": {"bias": 0.00921630859375, "kernel": 0.126953125}, "v_proj": {"bias": 0.0255126953125, "kernel": 0.30859375}}, "feed_forward": {"intermediate_dense": {"bias": 0.021484375, "kernel": 0.33984375}, "output_dense": {"bias": 0.01409912109375, "kernel": 0.26171875}}, "final_layer_norm": {"bias": 0.038818359375, "scale": 0.033935546875}, "layer_norm": {"bias": 0.05126953125, "scale": 0.050048828125}}, "8": {"attention": {"k_proj": {"bias": 7.915496826171875e-05, "kernel": 0.12060546875}, "out_proj": {"bias": 0.01507568359375, "kernel": 0.302734375}, "q_proj": {"bias": 0.007568359375, "kernel": 0.11474609375}, "v_proj": {"bias": 0.0262451171875, "kernel": 0.27734375}}, "feed_forward": {"intermediate_dense": {"bias": 0.023193359375, "kernel": 0.361328125}, "output_dense": {"bias": 0.01409912109375, "kernel": 0.275390625}}, "final_layer_norm": {"bias": 0.044677734375, "scale": 0.033447265625}, "layer_norm": {"bias": 0.0498046875, "scale": 0.06103515625}}, "9": {"attention": 
{"k_proj": {"bias": 0.00011777877807617188, "kernel": 0.1513671875}, "out_proj": {"bias": 0.0140380859375, "kernel": 0.416015625}, "q_proj": {"bias": 0.00830078125, "kernel": 0.1376953125}, "v_proj": {"bias": 0.0238037109375, "kernel": 0.40234375}}, "feed_forward": {"intermediate_dense": {"bias": 0.0213623046875, "kernel": 0.3515625}, "output_dense": {"bias": 0.0135498046875, "kernel": 0.279296875}}, "final_layer_norm": {"bias": 0.037109375, "scale": 0.0322265625}, "layer_norm": {"bias": 0.04443359375, "scale": 0.04833984375}}}, "pos_conv_embed": {"conv": {"bias": 0.034912109375, "weight_g": 0.044189453125, "weight_v": 0.287109375}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "1": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "2": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "3": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "4": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "5": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "6": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}}}, "feature_projection": {"layer_norm": {"bias": 0.1337890625, "scale": 0.1611328125}, "projection": {"bias": 0.0556640625, "kernel": 1.0546875}}, "masked_spec_embed": 0.0}}, "layer_param_norm/": {"lm_head": {"bias": 0.7824921607971191, "kernel": 55.72966766357422}, "wav2vec2": {"encoder": {"layer_norm": {"bias": 59.6768798828125, "scale": 74.17054748535156}, "layers": {"0": {"attention": {"k_proj": {"bias": 0.37033993005752563, "kernel": 27.536663055419922}, "out_proj": {"bias": 1.6469175815582275, "kernel": 26.147050857543945}, "q_proj": {"bias": 1.5330281257629395, "kernel": 27.813282012939453}, "v_proj": {"bias": 0.44783300161361694, "kernel": 26.55841064453125}}, "feed_forward": {"intermediate_dense": {"bias": 1.9835489988327026, "kernel": 100.66567993164062}, "output_dense": {"bias": 1.116748571395874, "kernel": 96.76679992675781}}, "final_layer_norm": {"bias": 1.335214376449585, "scale": 19.85782241821289}, "layer_norm": {"bias": 2.923041343688965, "scale": 15.398418426513672}}, "1": {"attention": {"k_proj": {"bias": 0.3767347037792206, "kernel": 41.013240814208984}, "out_proj": {"bias": 1.3653483390808105, "kernel": 43.371070861816406}, "q_proj": {"bias": 3.0925614833831787, "kernel": 41.05661392211914}, "v_proj": {"bias": 0.2924947738647461, "kernel": 41.61189270019531}}, "feed_forward": {"intermediate_dense": {"bias": 1.9882662296295166, "kernel": 98.7442626953125}, "output_dense": {"bias": 0.8527815341949463, "kernel": 87.83541870117188}}, "final_layer_norm": {"bias": 1.3605934381484985, "scale": 19.084806442260742}, "layer_norm": {"bias": 1.9318764209747314, "scale": 17.761367797851562}}, "10": {"attention": {"k_proj": {"bias": 0.4123449921607971, "kernel": 49.44670486450195}, "out_proj": {"bias": 1.3130683898925781, "kernel": 52.27025604248047}, "q_proj": {"bias": 2.48445200920105, "kernel": 49.528873443603516}, "v_proj": {"bias": 0.3416975736618042, "kernel": 52.344085693359375}}, "feed_forward": {"intermediate_dense": {"bias": 1.974550485610962, "kernel": 102.70410919189453}, "output_dense": {"bias": 0.5955485105514526, "kernel": 95.81275939941406}}, "final_layer_norm": {"bias": 2.3762400150299072, "scale": 20.81279754638672}, "layer_norm": {"bias": 1.806241512298584, "scale": 21.429487228393555}}, "11": 
{"attention": {"k_proj": {"bias": 0.4460787773132324, "kernel": 49.37338638305664}, "out_proj": {"bias": 1.1512463092803955, "kernel": 51.949554443359375}, "q_proj": {"bias": 2.5446064472198486, "kernel": 49.20353698730469}, "v_proj": {"bias": 0.40995872020721436, "kernel": 52.2607536315918}}, "feed_forward": {"intermediate_dense": {"bias": 2.019528388977051, "kernel": 103.56025695800781}, "output_dense": {"bias": 0.5711302757263184, "kernel": 97.53776550292969}}, "final_layer_norm": {"bias": 2.3660812377929688, "scale": 20.919677734375}, "layer_norm": {"bias": 1.7802445888519287, "scale": 22.01519203186035}}, "12": {"attention": {"k_proj": {"bias": 0.4312320351600647, "kernel": 50.07032775878906}, "out_proj": {"bias": 1.126122236251831, "kernel": 51.988826751708984}, "q_proj": {"bias": 2.4090728759765625, "kernel": 49.92729949951172}, "v_proj": {"bias": 0.4024103879928589, "kernel": 52.296756744384766}}, "feed_forward": {"intermediate_dense": {"bias": 2.0548508167266846, "kernel": 104.54823303222656}, "output_dense": {"bias": 0.5548778772354126, "kernel": 99.2693099975586}}, "final_layer_norm": {"bias": 2.2933573722839355, "scale": 20.85626983642578}, "layer_norm": {"bias": 1.8587299585342407, "scale": 22.473487854003906}}, "13": {"attention": {"k_proj": {"bias": 0.4430793821811676, "kernel": 51.76431655883789}, "out_proj": {"bias": 1.1271920204162598, "kernel": 51.86852264404297}, "q_proj": {"bias": 2.359200954437256, "kernel": 51.75225830078125}, "v_proj": {"bias": 0.3906242251396179, "kernel": 51.92781066894531}}, "feed_forward": {"intermediate_dense": {"bias": 2.0941619873046875, "kernel": 105.30806732177734}, "output_dense": {"bias": 0.5719542503356934, "kernel": 99.8712158203125}}, "final_layer_norm": {"bias": 2.2314066886901855, "scale": 21.027400970458984}, "layer_norm": {"bias": 1.9997800588607788, "scale": 22.84510040283203}}, "14": {"attention": {"k_proj": {"bias": 0.43604815006256104, "kernel": 51.92181396484375}, "out_proj": {"bias": 1.2681217193603516, "kernel": 49.762760162353516}, "q_proj": {"bias": 2.4942922592163086, "kernel": 52.049808502197266}, "v_proj": {"bias": 0.36820662021636963, "kernel": 49.26283264160156}}, "feed_forward": {"intermediate_dense": {"bias": 2.1335830688476562, "kernel": 105.94457244873047}, "output_dense": {"bias": 0.6063626408576965, "kernel": 101.24859619140625}}, "final_layer_norm": {"bias": 2.2754664421081543, "scale": 21.145992279052734}, "layer_norm": {"bias": 2.1295526027679443, "scale": 22.584672927856445}}, "15": {"attention": {"k_proj": {"bias": 0.45931702852249146, "kernel": 51.93058776855469}, "out_proj": {"bias": 1.3753983974456787, "kernel": 50.94449234008789}, "q_proj": {"bias": 2.5871663093566895, "kernel": 52.099769592285156}, "v_proj": {"bias": 0.459650456905365, "kernel": 50.5831298828125}}, "feed_forward": {"intermediate_dense": {"bias": 2.132938861846924, "kernel": 105.57443237304688}, "output_dense": {"bias": 0.7701732516288757, "kernel": 101.94094848632812}}, "final_layer_norm": {"bias": 2.327320098876953, "scale": 21.192947387695312}, "layer_norm": {"bias": 2.4148712158203125, "scale": 23.526634216308594}}, "16": {"attention": {"k_proj": {"bias": 0.4008745551109314, "kernel": 51.772621154785156}, "out_proj": {"bias": 1.27531099319458, "kernel": 50.134521484375}, "q_proj": {"bias": 2.667466163635254, "kernel": 51.75814437866211}, "v_proj": {"bias": 0.3768249750137329, "kernel": 49.78093719482422}}, "feed_forward": {"intermediate_dense": {"bias": 2.1094985008239746, "kernel": 106.10227966308594}, "output_dense": {"bias": 
0.7860437631607056, "kernel": 102.66590881347656}}, "final_layer_norm": {"bias": 2.337951421737671, "scale": 21.583194732666016}, "layer_norm": {"bias": 2.283249855041504, "scale": 22.168060302734375}}, "17": {"attention": {"k_proj": {"bias": 0.3966267704963684, "kernel": 51.728885650634766}, "out_proj": {"bias": 1.2151354551315308, "kernel": 49.4556884765625}, "q_proj": {"bias": 2.714320182800293, "kernel": 51.81880187988281}, "v_proj": {"bias": 0.42661017179489136, "kernel": 49.11927032470703}}, "feed_forward": {"intermediate_dense": {"bias": 2.1101765632629395, "kernel": 107.14872741699219}, "output_dense": {"bias": 0.8218655586242676, "kernel": 103.06423950195312}}, "final_layer_norm": {"bias": 2.383938789367676, "scale": 22.070323944091797}, "layer_norm": {"bias": 2.222898483276367, "scale": 21.219982147216797}}, "18": {"attention": {"k_proj": {"bias": 0.4409676194190979, "kernel": 52.41611099243164}, "out_proj": {"bias": 1.3447906970977783, "kernel": 50.491905212402344}, "q_proj": {"bias": 2.614685535430908, "kernel": 52.796600341796875}, "v_proj": {"bias": 0.4518332779407501, "kernel": 50.001895904541016}}, "feed_forward": {"intermediate_dense": {"bias": 2.144195556640625, "kernel": 107.41338348388672}, "output_dense": {"bias": 0.9481453895568848, "kernel": 104.72514343261719}}, "final_layer_norm": {"bias": 2.5390124320983887, "scale": 22.15178680419922}, "layer_norm": {"bias": 2.424910068511963, "scale": 23.585906982421875}}, "19": {"attention": {"k_proj": {"bias": 0.38193291425704956, "kernel": 51.511146545410156}, "out_proj": {"bias": 1.3303101062774658, "kernel": 50.10035705566406}, "q_proj": {"bias": 2.930327892303467, "kernel": 51.865638732910156}, "v_proj": {"bias": 0.4086824655532837, "kernel": 49.38078308105469}}, "feed_forward": {"intermediate_dense": {"bias": 2.1912901401519775, "kernel": 107.95254516601562}, "output_dense": {"bias": 1.0248571634292603, "kernel": 105.65098571777344}}, "final_layer_norm": {"bias": 2.4923481941223145, "scale": 22.505674362182617}, "layer_norm": {"bias": 2.2888314723968506, "scale": 22.31826400756836}}, "2": {"attention": {"k_proj": {"bias": 0.454792320728302, "kernel": 47.77275085449219}, "out_proj": {"bias": 1.256988525390625, "kernel": 45.969764709472656}, "q_proj": {"bias": 3.2510807514190674, "kernel": 47.61664581298828}, "v_proj": {"bias": 0.339598685503006, "kernel": 45.72273254394531}}, "feed_forward": {"intermediate_dense": {"bias": 1.9737317562103271, "kernel": 103.32754516601562}, "output_dense": {"bias": 0.7398276329040527, "kernel": 91.11263275146484}}, "final_layer_norm": {"bias": 1.5421981811523438, "scale": 21.561111450195312}, "layer_norm": {"bias": 1.7081801891326904, "scale": 20.852447509765625}}, "20": {"attention": {"k_proj": {"bias": 0.4067543148994446, "kernel": 51.605438232421875}, "out_proj": {"bias": 1.359946370124817, "kernel": 49.45553207397461}, "q_proj": {"bias": 2.8498687744140625, "kernel": 52.224571228027344}, "v_proj": {"bias": 0.36227869987487793, "kernel": 48.43864822387695}}, "feed_forward": {"intermediate_dense": {"bias": 2.1725549697875977, "kernel": 109.17405700683594}, "output_dense": {"bias": 1.1388803720474243, "kernel": 106.40528106689453}}, "final_layer_norm": {"bias": 2.435314655303955, "scale": 23.4317626953125}, "layer_norm": {"bias": 2.231672525405884, "scale": 22.230525970458984}}, "21": {"attention": {"k_proj": {"bias": 0.4161534905433655, "kernel": 51.942527770996094}, "out_proj": {"bias": 1.403618335723877, "kernel": 49.51059341430664}, "q_proj": {"bias": 2.7690629959106445, "kernel": 
52.67078399658203}, "v_proj": {"bias": 0.41060006618499756, "kernel": 48.64883041381836}}, "feed_forward": {"intermediate_dense": {"bias": 2.2174296379089355, "kernel": 109.5155029296875}, "output_dense": {"bias": 1.253208041191101, "kernel": 106.88243865966797}}, "final_layer_norm": {"bias": 2.4632763862609863, "scale": 23.175764083862305}, "layer_norm": {"bias": 2.2785892486572266, "scale": 22.234222412109375}}, "22": {"attention": {"k_proj": {"bias": 0.45357397198677063, "kernel": 52.54576110839844}, "out_proj": {"bias": 1.349219560623169, "kernel": 49.533172607421875}, "q_proj": {"bias": 2.8105549812316895, "kernel": 52.86981201171875}, "v_proj": {"bias": 0.3973655700683594, "kernel": 49.33363342285156}}, "feed_forward": {"intermediate_dense": {"bias": 2.1619315147399902, "kernel": 109.95498657226562}, "output_dense": {"bias": 1.3076066970825195, "kernel": 106.3852310180664}}, "final_layer_norm": {"bias": 2.3642821311950684, "scale": 22.684059143066406}, "layer_norm": {"bias": 2.3316237926483154, "scale": 21.545879364013672}}, "23": {"attention": {"k_proj": {"bias": 0.4928613007068634, "kernel": 53.47669219970703}, "out_proj": {"bias": 1.564335823059082, "kernel": 50.98707580566406}, "q_proj": {"bias": 2.7065773010253906, "kernel": 53.582611083984375}, "v_proj": {"bias": 0.5810648202896118, "kernel": 51.54853820800781}}, "feed_forward": {"intermediate_dense": {"bias": 2.131969690322876, "kernel": 109.86410522460938}, "output_dense": {"bias": 1.2769315242767334, "kernel": 107.37890625}}, "final_layer_norm": {"bias": 2.767916679382324, "scale": 22.887813568115234}, "layer_norm": {"bias": 2.824352264404297, "scale": 23.373172760009766}}, "24": {"attention": {"k_proj": {"bias": 0.46056002378463745, "kernel": 52.424072265625}, "out_proj": {"bias": 1.6070430278778076, "kernel": 52.50334167480469}, "q_proj": {"bias": 2.828113079071045, "kernel": 52.40515899658203}, "v_proj": {"bias": 0.5424190163612366, "kernel": 52.51116180419922}}, "feed_forward": {"intermediate_dense": {"bias": 2.2367913722991943, "kernel": 109.35035705566406}, "output_dense": {"bias": 1.3016372919082642, "kernel": 110.30095672607422}}, "final_layer_norm": {"bias": 2.83841872215271, "scale": 22.964658737182617}, "layer_norm": {"bias": 2.56215763092041, "scale": 22.983924865722656}}, "25": {"attention": {"k_proj": {"bias": 0.42509031295776367, "kernel": 52.730464935302734}, "out_proj": {"bias": 1.363797664642334, "kernel": 50.5806884765625}, "q_proj": {"bias": 2.9342763423919678, "kernel": 52.548744201660156}, "v_proj": {"bias": 0.6404213309288025, "kernel": 51.0885009765625}}, "feed_forward": {"intermediate_dense": {"bias": 2.1367578506469727, "kernel": 109.70021057128906}, "output_dense": {"bias": 1.1017413139343262, "kernel": 110.27072143554688}}, "final_layer_norm": {"bias": 2.5763301849365234, "scale": 23.494670867919922}, "layer_norm": {"bias": 2.683134078979492, "scale": 21.88357925415039}}, "26": {"attention": {"k_proj": {"bias": 0.4836847186088562, "kernel": 53.01764678955078}, "out_proj": {"bias": 1.2433912754058838, "kernel": 51.37077331542969}, "q_proj": {"bias": 2.943906784057617, "kernel": 52.80891036987305}, "v_proj": {"bias": 0.5064959526062012, "kernel": 52.004638671875}}, "feed_forward": {"intermediate_dense": {"bias": 2.2763516902923584, "kernel": 109.44652557373047}, "output_dense": {"bias": 1.0912110805511475, "kernel": 107.40899658203125}}, "final_layer_norm": {"bias": 2.1937994956970215, "scale": 22.433353424072266}, "layer_norm": {"bias": 2.497119903564453, "scale": 22.19057273864746}}, "27": 
{"attention": {"k_proj": {"bias": 0.5808594226837158, "kernel": 53.76898956298828}, "out_proj": {"bias": 1.5447406768798828, "kernel": 52.95805358886719}, "q_proj": {"bias": 2.703345775604248, "kernel": 53.69578552246094}, "v_proj": {"bias": 0.6748642325401306, "kernel": 53.388118743896484}}, "feed_forward": {"intermediate_dense": {"bias": 2.404933452606201, "kernel": 107.8713150024414}, "output_dense": {"bias": 0.9485896825790405, "kernel": 107.17198181152344}}, "final_layer_norm": {"bias": 2.5252954959869385, "scale": 21.88959503173828}, "layer_norm": {"bias": 2.6147172451019287, "scale": 23.32440948486328}}, "28": {"attention": {"k_proj": {"bias": 0.5901432037353516, "kernel": 54.482521057128906}, "out_proj": {"bias": 1.5367379188537598, "kernel": 53.31493377685547}, "q_proj": {"bias": 2.9472482204437256, "kernel": 54.1741943359375}, "v_proj": {"bias": 0.5131911039352417, "kernel": 53.759761810302734}}, "feed_forward": {"intermediate_dense": {"bias": 2.3475265502929688, "kernel": 107.87416076660156}, "output_dense": {"bias": 0.8224154710769653, "kernel": 109.1680908203125}}, "final_layer_norm": {"bias": 2.425306797027588, "scale": 22.337677001953125}, "layer_norm": {"bias": 2.0914058685302734, "scale": 23.993711471557617}}, "29": {"attention": {"k_proj": {"bias": 0.46781182289123535, "kernel": 51.12034606933594}, "out_proj": {"bias": 1.5021522045135498, "kernel": 55.685630798339844}, "q_proj": {"bias": 2.809702157974243, "kernel": 51.00274658203125}, "v_proj": {"bias": 0.4760415554046631, "kernel": 55.703304290771484}}, "feed_forward": {"intermediate_dense": {"bias": 2.297222137451172, "kernel": 108.01033020019531}, "output_dense": {"bias": 0.9597339630126953, "kernel": 113.12825012207031}}, "final_layer_norm": {"bias": 2.5980498790740967, "scale": 23.459980010986328}, "layer_norm": {"bias": 2.245180130004883, "scale": 25.39927864074707}}, "3": {"attention": {"k_proj": {"bias": 0.45006245374679565, "kernel": 52.03215789794922}, "out_proj": {"bias": 1.4254932403564453, "kernel": 48.60858917236328}, "q_proj": {"bias": 2.8560738563537598, "kernel": 52.312644958496094}, "v_proj": {"bias": 0.3246268630027771, "kernel": 48.768699645996094}}, "feed_forward": {"intermediate_dense": {"bias": 1.9663825035095215, "kernel": 104.83622741699219}, "output_dense": {"bias": 0.6984099745750427, "kernel": 94.07957458496094}}, "final_layer_norm": {"bias": 1.8095453977584839, "scale": 21.664737701416016}, "layer_norm": {"bias": 1.9017157554626465, "scale": 22.739452362060547}}, "30": {"attention": {"k_proj": {"bias": 0.5024805665016174, "kernel": 52.825706481933594}, "out_proj": {"bias": 1.3023658990859985, "kernel": 52.053871154785156}, "q_proj": {"bias": 2.907101631164551, "kernel": 52.91836166381836}, "v_proj": {"bias": 0.49308842420578003, "kernel": 52.49382019042969}}, "feed_forward": {"intermediate_dense": {"bias": 2.2399911880493164, "kernel": 108.17861938476562}, "output_dense": {"bias": 0.9140658378601074, "kernel": 112.09104919433594}}, "final_layer_norm": {"bias": 2.4926414489746094, "scale": 24.492368698120117}, "layer_norm": {"bias": 2.316732168197632, "scale": 24.931156158447266}}, "31": {"attention": {"k_proj": {"bias": 0.5412741899490356, "kernel": 51.240806579589844}, "out_proj": {"bias": 1.2333163022994995, "kernel": 52.19988250732422}, "q_proj": {"bias": 2.6581294536590576, "kernel": 51.346168518066406}, "v_proj": {"bias": 0.5469827651977539, "kernel": 52.432586669921875}}, "feed_forward": {"intermediate_dense": {"bias": 2.3097352981567383, "kernel": 106.72758483886719}, "output_dense": 
{"bias": 1.0891624689102173, "kernel": 109.24717712402344}}, "final_layer_norm": {"bias": 2.2962756156921387, "scale": 24.31252670288086}, "layer_norm": {"bias": 2.3430848121643066, "scale": 24.590187072753906}}, "32": {"attention": {"k_proj": {"bias": 0.4704548716545105, "kernel": 50.39933776855469}, "out_proj": {"bias": 1.2453913688659668, "kernel": 51.57465362548828}, "q_proj": {"bias": 2.8450098037719727, "kernel": 50.34847640991211}, "v_proj": {"bias": 0.419519305229187, "kernel": 51.972320556640625}}, "feed_forward": {"intermediate_dense": {"bias": 2.2569355964660645, "kernel": 105.33018493652344}, "output_dense": {"bias": 1.146787166595459, "kernel": 108.35574340820312}}, "final_layer_norm": {"bias": 2.31538724899292, "scale": 24.518985748291016}, "layer_norm": {"bias": 2.417579174041748, "scale": 24.991830825805664}}, "33": {"attention": {"k_proj": {"bias": 0.48390907049179077, "kernel": 50.28266143798828}, "out_proj": {"bias": 1.280959129333496, "kernel": 51.30557632446289}, "q_proj": {"bias": 2.998173713684082, "kernel": 50.253868103027344}, "v_proj": {"bias": 0.4416005611419678, "kernel": 51.71035385131836}}, "feed_forward": {"intermediate_dense": {"bias": 2.279946804046631, "kernel": 103.67684173583984}, "output_dense": {"bias": 1.1754591464996338, "kernel": 106.81501007080078}}, "final_layer_norm": {"bias": 2.257889747619629, "scale": 24.207740783691406}, "layer_norm": {"bias": 2.5865607261657715, "scale": 25.067882537841797}}, "34": {"attention": {"k_proj": {"bias": 0.45390456914901733, "kernel": 49.26523208618164}, "out_proj": {"bias": 1.5265402793884277, "kernel": 52.46200942993164}, "q_proj": {"bias": 2.913527011871338, "kernel": 49.268951416015625}, "v_proj": {"bias": 0.4023621678352356, "kernel": 52.534793853759766}}, "feed_forward": {"intermediate_dense": {"bias": 2.3745017051696777, "kernel": 102.21138000488281}, "output_dense": {"bias": 1.1227428913116455, "kernel": 105.72161102294922}}, "final_layer_norm": {"bias": 2.20273494720459, "scale": 23.640857696533203}, "layer_norm": {"bias": 2.615731716156006, "scale": 25.498104095458984}}, "35": {"attention": {"k_proj": {"bias": 0.5336894989013672, "kernel": 51.04902648925781}, "out_proj": {"bias": 1.4900906085968018, "kernel": 51.15006637573242}, "q_proj": {"bias": 2.573650360107422, "kernel": 51.323951721191406}, "v_proj": {"bias": 0.4889468550682068, "kernel": 51.176795959472656}}, "feed_forward": {"intermediate_dense": {"bias": 2.502547264099121, "kernel": 100.7535400390625}, "output_dense": {"bias": 1.0254027843475342, "kernel": 104.25007629394531}}, "final_layer_norm": {"bias": 2.295243263244629, "scale": 23.61981964111328}, "layer_norm": {"bias": 2.499863862991333, "scale": 26.093618392944336}}, "36": {"attention": {"k_proj": {"bias": 0.4473738670349121, "kernel": 48.319740295410156}, "out_proj": {"bias": 1.5177178382873535, "kernel": 52.26478958129883}, "q_proj": {"bias": 2.6124308109283447, "kernel": 48.236427307128906}, "v_proj": {"bias": 0.39316776394844055, "kernel": 52.66499328613281}}, "feed_forward": {"intermediate_dense": {"bias": 2.3669564723968506, "kernel": 99.60517120361328}, "output_dense": {"bias": 1.0260361433029175, "kernel": 103.68067932128906}}, "final_layer_norm": {"bias": 2.044203519821167, "scale": 24.13540267944336}, "layer_norm": {"bias": 2.2894434928894043, "scale": 25.63806915283203}}, "37": {"attention": {"k_proj": {"bias": 0.6240901350975037, "kernel": 47.300960540771484}, "out_proj": {"bias": 1.7604304552078247, "kernel": 52.189971923828125}, "q_proj": {"bias": 2.3819239139556885, 
"kernel": 47.316253662109375}, "v_proj": {"bias": 0.38518577814102173, "kernel": 52.32656478881836}}, "feed_forward": {"intermediate_dense": {"bias": 2.2694783210754395, "kernel": 98.56558227539062}, "output_dense": {"bias": 1.0087021589279175, "kernel": 103.13001251220703}}, "final_layer_norm": {"bias": 1.7903382778167725, "scale": 24.500085830688477}, "layer_norm": {"bias": 2.238807439804077, "scale": 25.563133239746094}}, "38": {"attention": {"k_proj": {"bias": 0.722891092300415, "kernel": 45.45281219482422}, "out_proj": {"bias": 1.4463956356048584, "kernel": 51.495086669921875}, "q_proj": {"bias": 2.2622861862182617, "kernel": 45.454437255859375}, "v_proj": {"bias": 0.42936205863952637, "kernel": 51.565425872802734}}, "feed_forward": {"intermediate_dense": {"bias": 2.2012171745300293, "kernel": 96.4342041015625}, "output_dense": {"bias": 0.9817801713943481, "kernel": 101.30387115478516}}, "final_layer_norm": {"bias": 1.7825044393539429, "scale": 25.214679718017578}, "layer_norm": {"bias": 2.4084482192993164, "scale": 26.425636291503906}}, "39": {"attention": {"k_proj": {"bias": 0.7209377884864807, "kernel": 45.24742889404297}, "out_proj": {"bias": 1.7097458839416504, "kernel": 51.329002380371094}, "q_proj": {"bias": 2.1152358055114746, "kernel": 45.53700637817383}, "v_proj": {"bias": 0.4246286153793335, "kernel": 51.27728271484375}}, "feed_forward": {"intermediate_dense": {"bias": 2.177624225616455, "kernel": 94.23902893066406}, "output_dense": {"bias": 1.0458879470825195, "kernel": 101.17718505859375}}, "final_layer_norm": {"bias": 1.753927230834961, "scale": 25.785865783691406}, "layer_norm": {"bias": 2.33198881149292, "scale": 26.942312240600586}}, "4": {"attention": {"k_proj": {"bias": 0.44548165798187256, "kernel": 54.6234130859375}, "out_proj": {"bias": 1.652343988418579, "kernel": 50.16497039794922}, "q_proj": {"bias": 2.615248918533325, "kernel": 54.93098068237305}, "v_proj": {"bias": 0.34892427921295166, "kernel": 50.320098876953125}}, "feed_forward": {"intermediate_dense": {"bias": 1.9531786441802979, "kernel": 104.50968933105469}, "output_dense": {"bias": 0.8575068712234497, "kernel": 95.54541015625}}, "final_layer_norm": {"bias": 1.9920002222061157, "scale": 21.200180053710938}, "layer_norm": {"bias": 2.054612159729004, "scale": 23.620338439941406}}, "40": {"attention": {"k_proj": {"bias": 0.6590453386306763, "kernel": 44.198089599609375}, "out_proj": {"bias": 1.6252505779266357, "kernel": 49.55699920654297}, "q_proj": {"bias": 1.9674756526947021, "kernel": 44.89208221435547}, "v_proj": {"bias": 0.4587768614292145, "kernel": 49.23614501953125}}, "feed_forward": {"intermediate_dense": {"bias": 2.0333969593048096, "kernel": 92.16896057128906}, "output_dense": {"bias": 1.087776780128479, "kernel": 98.40738677978516}}, "final_layer_norm": {"bias": 1.7852704524993896, "scale": 25.04292106628418}, "layer_norm": {"bias": 2.2756104469299316, "scale": 26.40799903869629}}, "41": {"attention": {"k_proj": {"bias": 1.7133712768554688, "kernel": 41.96858596801758}, "out_proj": {"bias": 1.3790823221206665, "kernel": 51.25593566894531}, "q_proj": {"bias": 1.71382737159729, "kernel": 42.56317901611328}, "v_proj": {"bias": 0.4695759415626526, "kernel": 50.369529724121094}}, "feed_forward": {"intermediate_dense": {"bias": 2.110393524169922, "kernel": 88.92567443847656}, "output_dense": {"bias": 1.1446669101715088, "kernel": 97.37409973144531}}, "final_layer_norm": {"bias": 2.23917293548584, "scale": 28.507984161376953}, "layer_norm": {"bias": 2.22525691986084, "scale": 28.246891021728516}}, 
"42": {"attention": {"k_proj": {"bias": 0.8601109981536865, "kernel": 38.31235885620117}, "out_proj": {"bias": 1.4427157640457153, "kernel": 45.07648849487305}, "q_proj": {"bias": 1.549715280532837, "kernel": 39.524009704589844}, "v_proj": {"bias": 0.6933339834213257, "kernel": 43.49076461791992}}, "feed_forward": {"intermediate_dense": {"bias": 1.9107489585876465, "kernel": 88.009765625}, "output_dense": {"bias": 1.1978566646575928, "kernel": 95.7593994140625}}, "final_layer_norm": {"bias": 1.9227323532104492, "scale": 29.817535400390625}, "layer_norm": {"bias": 1.6761282682418823, "scale": 26.810440063476562}}, "43": {"attention": {"k_proj": {"bias": 1.247081995010376, "kernel": 34.694725036621094}, "out_proj": {"bias": 1.4174811840057373, "kernel": 41.36320495605469}, "q_proj": {"bias": 1.3773530721664429, "kernel": 35.38981628417969}, "v_proj": {"bias": 0.5787136554718018, "kernel": 39.29212951660156}}, "feed_forward": {"intermediate_dense": {"bias": 1.8936835527420044, "kernel": 87.073974609375}, "output_dense": {"bias": 0.9419379234313965, "kernel": 93.74283599853516}}, "final_layer_norm": {"bias": 1.99924898147583, "scale": 32.0491943359375}, "layer_norm": {"bias": 1.7940990924835205, "scale": 25.131242752075195}}, "44": {"attention": {"k_proj": {"bias": 2.5188145637512207, "kernel": 35.16314697265625}, "out_proj": {"bias": 1.1667875051498413, "kernel": 45.019126892089844}, "q_proj": {"bias": 1.317202091217041, "kernel": 35.58300018310547}, "v_proj": {"bias": 0.38874924182891846, "kernel": 44.14718246459961}}, "feed_forward": {"intermediate_dense": {"bias": 1.9462898969650269, "kernel": 86.07953643798828}, "output_dense": {"bias": 0.859266996383667, "kernel": 91.58765411376953}}, "final_layer_norm": {"bias": 2.0454273223876953, "scale": 34.2881965637207}, "layer_norm": {"bias": 1.6815991401672363, "scale": 25.14142608642578}}, "45": {"attention": {"k_proj": {"bias": 2.081407308578491, "kernel": 34.86139678955078}, "out_proj": {"bias": 1.0356104373931885, "kernel": 48.59937286376953}, "q_proj": {"bias": 1.402512788772583, "kernel": 35.03264617919922}, "v_proj": {"bias": 0.4231463074684143, "kernel": 48.76853942871094}}, "feed_forward": {"intermediate_dense": {"bias": 2.016927719116211, "kernel": 82.93773651123047}, "output_dense": {"bias": 0.9764893054962158, "kernel": 87.24796295166016}}, "final_layer_norm": {"bias": 1.9180456399917603, "scale": 33.143672943115234}, "layer_norm": {"bias": 1.5726068019866943, "scale": 23.782546997070312}}, "46": {"attention": {"k_proj": {"bias": 1.5659263134002686, "kernel": 35.878021240234375}, "out_proj": {"bias": 0.8182340264320374, "kernel": 51.16078186035156}, "q_proj": {"bias": 1.5642974376678467, "kernel": 36.18907165527344}, "v_proj": {"bias": 0.4092414081096649, "kernel": 51.89159393310547}}, "feed_forward": {"intermediate_dense": {"bias": 2.0093321800231934, "kernel": 77.47581481933594}, "output_dense": {"bias": 1.1406863927841187, "kernel": 77.73695373535156}}, "final_layer_norm": {"bias": 1.8108854293823242, "scale": 28.70657730102539}, "layer_norm": {"bias": 1.3991491794586182, "scale": 22.808137893676758}}, "47": {"attention": {"k_proj": {"bias": 0.6173280477523804, "kernel": 38.678985595703125}, "out_proj": {"bias": 0.6758822202682495, "kernel": 46.45281219482422}, "q_proj": {"bias": 1.7084776163101196, "kernel": 39.426841735839844}, "v_proj": {"bias": 0.4932914674282074, "kernel": 47.617279052734375}}, "feed_forward": {"intermediate_dense": {"bias": 1.986911654472351, "kernel": 75.47482299804688}, "output_dense": {"bias": 
0.6346586346626282, "kernel": 72.82707214355469}}, "final_layer_norm": {"bias": 1.1888140439987183, "scale": 23.650447845458984}, "layer_norm": {"bias": 1.2521969079971313, "scale": 20.66573715209961}}, "5": {"attention": {"k_proj": {"bias": 0.42588678002357483, "kernel": 50.1945686340332}, "out_proj": {"bias": 1.6038882732391357, "kernel": 51.2144889831543}, "q_proj": {"bias": 2.7522244453430176, "kernel": 50.37500762939453}, "v_proj": {"bias": 0.3343381881713867, "kernel": 51.71652603149414}}, "feed_forward": {"intermediate_dense": {"bias": 1.8887722492218018, "kernel": 104.60663604736328}, "output_dense": {"bias": 0.8976269960403442, "kernel": 94.77360534667969}}, "final_layer_norm": {"bias": 2.1965675354003906, "scale": 21.37998390197754}, "layer_norm": {"bias": 2.0435237884521484, "scale": 22.437192916870117}}, "6": {"attention": {"k_proj": {"bias": 0.4843112528324127, "kernel": 51.87700653076172}, "out_proj": {"bias": 1.5925445556640625, "kernel": 50.83113479614258}, "q_proj": {"bias": 2.7889723777770996, "kernel": 52.3514404296875}, "v_proj": {"bias": 0.3247200846672058, "kernel": 51.107398986816406}}, "feed_forward": {"intermediate_dense": {"bias": 1.8638136386871338, "kernel": 103.7142333984375}, "output_dense": {"bias": 0.752193808555603, "kernel": 94.57742309570312}}, "final_layer_norm": {"bias": 2.5145251750946045, "scale": 20.836563110351562}, "layer_norm": {"bias": 2.0285890102386475, "scale": 23.156789779663086}}, "7": {"attention": {"k_proj": {"bias": 0.5048109889030457, "kernel": 51.46453094482422}, "out_proj": {"bias": 1.4398455619812012, "kernel": 51.139068603515625}, "q_proj": {"bias": 2.550907611846924, "kernel": 51.92047119140625}, "v_proj": {"bias": 0.42719271779060364, "kernel": 50.953285217285156}}, "feed_forward": {"intermediate_dense": {"bias": 1.8739449977874756, "kernel": 103.49991607666016}, "output_dense": {"bias": 0.5876641273498535, "kernel": 94.39387512207031}}, "final_layer_norm": {"bias": 2.416801929473877, "scale": 21.010677337646484}, "layer_norm": {"bias": 1.9788501262664795, "scale": 22.19708824157715}}, "8": {"attention": {"k_proj": {"bias": 0.49711495637893677, "kernel": 51.12122344970703}, "out_proj": {"bias": 1.2548246383666992, "kernel": 51.65118408203125}, "q_proj": {"bias": 2.541980504989624, "kernel": 51.03407287597656}, "v_proj": {"bias": 0.35420340299606323, "kernel": 51.662872314453125}}, "feed_forward": {"intermediate_dense": {"bias": 1.9298467636108398, "kernel": 103.21916198730469}, "output_dense": {"bias": 0.5482766628265381, "kernel": 93.97574615478516}}, "final_layer_norm": {"bias": 2.3526053428649902, "scale": 20.7393856048584}, "layer_norm": {"bias": 1.9221248626708984, "scale": 22.40435028076172}}, "9": {"attention": {"k_proj": {"bias": 0.5231171250343323, "kernel": 52.01068878173828}, "out_proj": {"bias": 1.4968843460083008, "kernel": 52.671897888183594}, "q_proj": {"bias": 2.4629459381103516, "kernel": 52.26807403564453}, "v_proj": {"bias": 0.38445231318473816, "kernel": 52.86597442626953}}, "feed_forward": {"intermediate_dense": {"bias": 2.026733875274658, "kernel": 101.98421478271484}, "output_dense": {"bias": 0.6828575134277344, "kernel": 94.36962890625}}, "final_layer_norm": {"bias": 2.325080156326294, "scale": 20.160720825195312}, "layer_norm": {"bias": 2.0236480236053467, "scale": 24.083864212036133}}}, "pos_conv_embed": {"conv": {"bias": 5.847014427185059, "weight_g": 9.12463665008545, "weight_v": 93.52015686035156}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 2.0290679931640625, "kernel": 
20.55536460876465}, "layer_norm": {"bias": 4.550922393798828, "scale": 16.167570114135742}}, "1": {"conv": {"bias": 1.7790228128433228, "kernel": 51.24136734008789}, "layer_norm": {"bias": 5.962646961212158, "scale": 23.268157958984375}}, "2": {"conv": {"bias": 1.140576720237732, "kernel": 46.50312042236328}, "layer_norm": {"bias": 4.176670551300049, "scale": 20.370853424072266}}, "3": {"conv": {"bias": 0.6725863218307495, "kernel": 44.397525787353516}, "layer_norm": {"bias": 3.888174533843994, "scale": 17.53795051574707}}, "4": {"conv": {"bias": 0.6373162269592285, "kernel": 41.314056396484375}, "layer_norm": {"bias": 2.385471820831299, "scale": 16.34571647644043}}, "5": {"conv": {"bias": 0.5147221684455872, "kernel": 37.479759216308594}, "layer_norm": {"bias": 2.020900011062622, "scale": 17.064470291137695}}, "6": {"conv": {"bias": 0.4947893023490906, "kernel": 40.64780044555664}, "layer_norm": {"bias": 0.5876954793930054, "scale": 19.058603286743164}}}}, "feature_projection": {"layer_norm": {"bias": 6.376383304595947, "scale": 16.443069458007812}, "projection": {"bias": 1.8670344352722168, "kernel": 37.218414306640625}}, "masked_spec_embed": 11.914372444152832}}, "train/learning_rate": 1.9151924789184704e-05, "train/loss": 0.204779714345932, "train/param_norm": 1241.662353515625, "_runtime": 4765, "_timestamp": 1660121820, "_step": 275600, "_wandb": {"runtime": 4766}} \ No newline at end of file diff --git a/wandb/run-20220810_073735-23avj35z/logs/debug-internal.log b/wandb/run-20220810_073735-23avj35z/logs/debug-internal.log new file mode 100644 index 0000000000000000000000000000000000000000..85cd3551f94bbe83a2d167214a5e18f17911cff4 --- /dev/null +++ b/wandb/run-20220810_073735-23avj35z/logs/debug-internal.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:321db68741142aebf6da6dfd07396d57f1844a38e8782fb191cb8b2f9d6ad8f3 +size 182857 diff --git a/wandb/run-20220810_073735-23avj35z/logs/debug.log b/wandb/run-20220810_073735-23avj35z/logs/debug.log new file mode 100644 index 0000000000000000000000000000000000000000..a909682accfa357cc57ffbe3246f734775d3ba6f --- /dev/null +++ b/wandb/run-20220810_073735-23avj35z/logs/debug.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ac7e803317e52439f444d7689e70325ecbb39546789a4dccfc840ec06a3de97 +size 6204 diff --git a/wandb/run-20220810_073735-23avj35z/run-23avj35z.wandb b/wandb/run-20220810_073735-23avj35z/run-23avj35z.wandb new file mode 100644 index 0000000000000000000000000000000000000000..7f63e688be081771e321f7d72dd68524b9d7f4bc --- /dev/null +++ b/wandb/run-20220810_073735-23avj35z/run-23avj35z.wandb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c7bcca71ea7ef338baa85a6158b1edd93877dbbb7c2d16bb122668a85532e88 +size 772668 diff --git a/wandb/run-20220810_111559-290849gb/files/code/run_flax_speech_recognition_ctc.py b/wandb/run-20220810_111559-290849gb/files/code/run_flax_speech_recognition_ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..a3308794e399464a6649e4f73284f7a0586a92ce --- /dev/null +++ b/wandb/run-20220810_111559-290849gb/files/code/run_flax_speech_recognition_ctc.py @@ -0,0 +1,1631 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2022 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition. +""" +# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. + +import logging +import math +import os +import re +import sys +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + +import datasets +import numpy as np +from datasets import DatasetDict, load_dataset, load_metric +from tqdm import tqdm + +import flax +import jax +import jax.numpy as jnp +import optax +import transformers +import wandb as wandb +from flax import core, jax_utils, struct, traverse_util +from flax.jax_utils import unreplicate, pad_shard_unpad +from flax.training.common_utils import get_metrics, shard, shard_prng_key +from huggingface_hub import Repository +from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC +from optax._src import linear_algebra +from transformers import ( + AutoFeatureExtractor, + AutoProcessor, + AutoTokenizer, + HfArgumentParser, + TrainingArguments, + is_tensorboard_available, + set_seed, +) +from transformers.file_utils import get_full_repo_name +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.17.0.dev0") + +require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") + +logger = logging.getLogger(__name__) + + +@flax.struct.dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + feature_extractor_name: Optional[str] = field( + default=None, metadata={"help": "feature extractor name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." 
+ }, + ) + freeze_feature_encoder: bool = field( + default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."} + ) + attention_dropout: float = field( + default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."} + ) + activation_dropout: float = field( + default=0.1, + metadata={ + "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler." + }, + ) + hidden_dropout: float = field( + default=0.1, + metadata={ + "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler." + }, + ) + feat_proj_dropout: float = field( + default=0.0, + metadata={ + "help": "The feat proj dropout probability for feature encoder representations." + }, + ) + final_dropout: float = field( + default=0.0, + metadata={"help": "The dropout probability for the final projection layer."}, + ) + mask_time_prob: float = field( + default=0.1, + metadata={ + "help": "The spec aug dropout probability for feature encoder representations." + }, + ) + mask_time_length: int = field( + default=10, + metadata={"help": "Length of vector span to mask along the time axis."}, + ) + mask_feature_prob: float = field( + default=0.0, + metadata={ + "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector" + "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis." + }, + ) + mask_feature_length: int = field( + default=10, + metadata={"help": "Length of vector span to mask along the feature axis."}, + ) + layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."}) + ctc_loss_reduction: Optional[str] = field( + default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."} + ) + ctc_zero_infinity: Optional[bool] = field( + default=False, metadata={"help": "If True, will try to avoid the CTC loss going to infinity."} + ) + + +@flax.struct.dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: str = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + text_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the full texts (for transcription)."}, + ) + dataset_cache_dir: Optional[str] = field( + default=None, metadata={"help": "Path to cache directory for saving and loading datasets"} + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." 
+ }, + ) + max_test_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of test examples to this " + "value if set." + }, + ) + audio_column_name: str = field( + default="audio", + metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"}, + ) + text_column_name: str = field( + default="text", + metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"}, + ) + max_duration_in_seconds: float = field( + default=20.0, + metadata={ + "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to `max_duration_in_seconds`" + }, + ) + min_duration_in_seconds: float = field( + default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"} + ) + max_label_length: Optional[int] = field( + default=512, + metadata={ + "help": "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be filtered." + }, + ) + min_label_length: Optional[int] = field( + default=2, + metadata={ + "help": "The minimum total sequence length for target text after tokenization. Sequences shorter " + "than this will be filtered." + }, + ) + pad_input_to_multiple_of: Optional[int] = field( + default=32000, + metadata={ + "help": "If set will pad the input sequence to a multiple of the provided value. " + "This is important to avoid triggering recompilations on TPU." + }, + ) + pad_target_to_multiple_of: Optional[int] = field( + default=None, + metadata={ + "help": "If set will pad the target sequence to a multiple of the provided value. " + "This is important to avoid triggering recompilations on TPU." + }, + ) + preprocessing_only: bool = field( + default=False, + metadata={ + "help": "Whether to only do data preprocessing and skip training. " + "This is especially useful when data preprocessing errors out in distributed training due to timeout. " + "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` " + "so that the cached datasets can consequently be loaded in distributed training" + }, + ) + train_split_name: str = field( + default="train", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + eval_split_name: str = field( + default="validation", + metadata={ + "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'" + }, + ) + do_lower_case: bool = field( + default=True, + metadata={"help": "Whether the target text should be lower cased."}, + ) + wandb_project: str = field( + default="flax-speech-recognition-ctc", + metadata={"help": "The name of the wandb project."}, + ) + wandb_name: str = field( + default=None, + metadata={"help": "The name of the wandb run."}, + ) + wandb_job_type: str = field( + default="CTC", + metadata={"help": "The name of the wandb job type."}, + ) + test_split_name: str = field( + default="test", + metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"}, + ) + remove_punctuation: bool = field( + default=False, metadata={"help": "Whether or not to remove punctuation during training."} + ) + skip_steps: Optional[int] = field( + default=0, + metadata={ + "help": "Skip this number of steps. 
Useful to continue training" + }, + ) + + +# @flax.struct.dataclass +@dataclass +class FlaxTrainingArguments(TrainingArguments): + precision: str = field( + default="full", + metadata={ + "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision" + "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**" + }, + ) + matmul_precision: str = field( + default="default", + metadata={ + "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. " + "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). " + "This configuration option does not change the behaviours of such calls with explicit precision arguments; " + "it only changes the behaviors of calls with no such argument provided. " + "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`." + }, + ) + multisteps: bool = field( + default=False, + metadata={ + "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, " + "a custom gradient accumulation implementation will be employed." + }, + ) + + +def to_fp32(t): + return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t) + + +def to_bf16(t): + return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t) + + +class MixedPrecisionTrainState(struct.PyTreeNode): + """Train state for use with a single Optax optimizer. + Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py + + Synopsis:: + + state = TrainState.create( + apply_fn=model.apply, + params=variables['params'], + tx=tx) + grad_fn = jax.grad(make_loss_fn(state.apply_fn)) + for batch in data: + grads = grad_fn(state.params, batch) + state = state.apply_gradients(grads=grads) + + Args: + step: Counter starts at 0 and is incremented by every call to + `.apply_gradients()`. + apply_fn: Usually set to `model.apply()`. Kept in this dataclass for + convenience to have a shorter params list for the `train_step()` function + in your training loop. + params: The parameters to be updated by `tx` and used by `apply_fn`. + tx: An Optax gradient transformation. + opt_state: The state for `tx`. + dropout_rng: PRNG key for stochastic operations. + bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. + """ + + step: int + apply_fn: Callable = struct.field(pytree_node=False) + get_attention_mask_fn: Callable = struct.field(pytree_node=False) + params: core.FrozenDict[str, Any] + tx: optax.GradientTransformation = struct.field(pytree_node=False) + opt_state: optax.OptState + dropout_rng: jnp.ndarray + max_grad_norm: Optional[float] = 1.0 + + def apply_gradients(self, *, grads, to_dtype, **kwargs): + """Updates `step`, `params`, `opt_state` and `**kwargs` in return value. + + Note that internally this function calls `.tx.update()` followed by a call + to `optax.apply_updates()` to update `params` and `opt_state`. + + Args: + grads: Gradients that have the same pytree structure as `.params`. + **kwargs: Additional dataclass attributes that should be `.replace()`-ed. 
+ + Returns: + An updated instance of `self` with `step` incremented by one, `params` + and `opt_state` updated by applying `grads`, and additional attributes + replaced as specified by `kwargs`. + """ + + # clip gradients by global l2 norm + casted_max_grad_norm = to_dtype(self.max_grad_norm) + g_norm = linear_algebra.global_norm(grads) + g_norm = jnp.maximum(casted_max_grad_norm, g_norm) + grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads) + + # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training + # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is) + updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params) + + new_params = optax.apply_updates(self.params, updates) + return self.replace( + step=self.step + 1, + params=new_params, + opt_state=to_dtype(new_opt_state), + **kwargs, + ) + + @classmethod + def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs): + """Creates a new instance with `step=0` and initialized `opt_state`.""" + # downcast optimizer state to bf16 if mixed-precision training + opt_state = tx.init(to_dtype(params)) if tx is not None else None + return cls( + step=0, + apply_fn=apply_fn, + params=params, + tx=tx, + opt_state=opt_state, + **kwargs, + ) + + def replicate(self): + return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) + + +@flax.struct.dataclass +class FlaxDataCollatorSpeechSeq2SeqWithPadding: + """ + Data collator that will dynamically pad the inputs received. + Args: + processor ([`Wav2Vec2Processor`]) + The processor used for proccessing the data. + decoder_start_token_id (:obj: `int`) + The begin-of-sentence of the decoder. + input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned input sequences (according to the model's padding side and padding index) + among: + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned target sequences (according to the model's padding side and padding index). + See above for details. + max_input_length (:obj:`float`, `optional`): + Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). + pad_input_to_multiple_of (:obj:`int`, `optional`): + If set will pad the input sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + pad_target_to_multiple_of (:obj:`int`, `optional`): + If set will pad the target sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). 
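+ 
+        Example (illustrative sketch; assumes `processor` is a loaded :class:`~transformers.AutoProcessor`
+        and `features` is a list of dicts each providing `input_values` and `labels`)::
+ 
+            data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
+                processor=processor,
+                input_padding="longest",
+                pad_input_to_multiple_of=32000,
+                max_label_length=512,
+            )
+            batch = data_collator(features)  # numpy batch with padded "input_values" and "labels"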
+ """ + + processor: Any + input_padding: Union[bool, str] = "longest" + label_padding: Union[bool, str] = "max_length" + pad_input_to_multiple_of: Optional[int] = None + pad_to_multiple_of_label: Optional[int] = None + max_input_length: Optional[float] = None + max_label_length: Optional[float] = None + + def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]: + # split inputs and labels since they have to be of different lengths and need + # different padding methods + input_features = [{"input_values": feature["input_values"]} for feature in features] + label_features = [{"input_ids": feature["labels"]} for feature in features] + + # reformat list to dict and set to pytorch format + batch = self.processor.feature_extractor.pad( + input_features, + max_length=self.max_input_length, + padding=self.input_padding, + pad_to_multiple_of=self.pad_input_to_multiple_of, + return_tensors="np", + ) + + labels_batch = self.processor.tokenizer.pad( + label_features, + max_length=self.max_label_length, + padding=self.label_padding, + pad_to_multiple_of=self.pad_to_multiple_of_label, + return_tensors="np", + ) + + labels = labels_batch["input_ids"] + labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1)) + labels = labels.filled(fill_value=-100) + + batch["labels"] = labels + + return batch + + +def get_grouped_indices( + dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None +) -> np.array: + """ + Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486) + Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar + lengths. To do this, the indices are: + + - randomly permuted (if a JAX rng is specified) + - grouped in mega-batches of size `mega_batch_mult * batch_size` + - sorted by length in each mega-batch + + The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of + maximum length placed first, so that an OOM happens sooner rather than later. + """ + lengths = dataset["input_length"] + + # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller. + if mega_batch_mult is None: + mega_batch_mult = min(len(lengths) // (batch_size * 4), 50) + # Just in case, for tiny datasets + if mega_batch_mult == 0: + mega_batch_mult = 1 + + # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler. + num_samples = len(lengths) + indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples) + + megabatch_size = mega_batch_mult * batch_size + megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] + megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches] + + # The rest is to get the biggest batch first. 
+ # Since each megabatch is sorted by descending length, the longest element is the first + megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches] + max_idx = np.argmax(megabatch_maximums).item() + # Switch to put the longest batch in first position + # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch) + megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0] + + megabatches = np.array([i for megabatch in megabatches for i in megabatch]) + + return megabatches + + +def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray: + """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by + the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned.""" + num_samples = len(samples_idx) + if drop_last: + samples_to_remove = num_samples % batch_size + if samples_to_remove != 0: + samples_idx = samples_idx[:-samples_to_remove] + sections_split = num_samples // batch_size + samples_idx = samples_idx.reshape((sections_split, batch_size)) + else: + sections_split = math.ceil(num_samples / batch_size) + samples_idx = np.array_split(samples_idx, sections_split) + return samples_idx + + +def write_train_metric(summary_writer, train_metrics, train_time, step): + summary_writer.scalar("train_time", train_time, step) + + train_metrics = get_metrics(train_metrics) + for key, vals in train_metrics.items(): + tag = f"train_{key}" + for i, val in enumerate(vals): + summary_writer.scalar(tag, val, step - len(vals) + i + 1) + + +def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None): + for metric_name, value in eval_metrics.items(): + summary_writer.scalar(f"eval_{metric_name}", value, step) + + if pred_str is not None: + # write output actual predictions for debugging + summary_writer.text("eval_predictions", "\n".join(pred_str), step) + + +def write_wandb_log(metrics, step, prefix=None): + if jax.process_index() == 0: + log_metrics = {} + for k, v in metrics.items(): + if "layer" in k: + log_metrics[f"{k}/"] = v + elif prefix is not None: + log_metrics[f"{prefix}/{k}"] = v + else: + log_metrics[k] = v + wandb.log(log_metrics, step) + + +def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"): + if jax.process_index() == 0: + # convert str data to a wandb compatible format + str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))] + # we'll log the first 50 predictions for each epoch + wandb.log( + { + f"{prefix}/step_{int(step / 1000)}k": wandb.Table( + columns=["label_str", "pred_str"], data=str_data[:num_log] + ) + }, + step, + ) + + +def create_learning_rate_fn( + num_train_steps: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.array]: + """Returns a linear warmup, linear_decay learning rate function.""" + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +def ctc_loss( + logits, + logits_attention_mask, + labels, + blank_id, + loss_reduction="mean", + output_emission_dict=False, + log_epsilon=-100000.0, +): + """Computes CTC loss. 
+ This function performs forward computation over an FSA with `N * 2` states + where `N` is the max number of labels. The states are split into two groups: + Phi states and emission states. a phi-state accepts repetition of + phi (blank)-symbols and transits to emission state when the correct label is + observed. An emission state accepts repetition of the label and transits to + the next phi states at any time (so called epsilon-transition). + Below, `B` denotes the batch size, `T` denotes the time steps in `logits`, + and `N` denotes the time steps in `labels`. + Args: + logits: (B, T, K)-array containing log-probabilities of each class. + logitpaddings: (B, T)-array. Padding indicators for `logits`. + labels: (B, N)-array containing reference integer labels. + labelpaddings: (B, N)-array. Padding indicators for `labels`. Currently, + `labels` must be right-padded, i.e. each row of `labelpaddings` must be + repetition of zeroes, followed by repetition of ones. + blank_id: Id for blank token. + loss_reduction: one of "mean", "sum", "default" + - "none": no reduction is applied. + - "mean": output loss will be divided by target lengths and then the + mean over the batch is taken. + - "sum": output loss are summed over batch + output_emission_dict: whether to output additional information about the emission probs + Returns: + A pair of `(per_seq_loss, aux)`. + per_seq_loss: + (B,)-array containing loss values for each sequence in the batch. + aux: Dictionary containing interim variables used for computing losses. + aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each + phi-state corresponding to the n-th label. + aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each + emission-state corresponding to the n-th label. + aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol + corresponding to each time frame. + aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label + corresponding to each time frame. + """ + # label paddings are indicated by -100 + labelpaddings = labels < 0 + # logit paddings are the inverse of attention_mask + logitpaddings = ~logits_attention_mask + + # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py + batchsize, unused_maxinputlen, num_classes = logits.shape + batchsize_, maxlabellen = labels.shape + + logprobs = jax.nn.log_softmax(logits) + labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32) + + # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1]. 
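+    # Illustrative example: for a label row [a, a, b], repeat == [1., 0., 0.] after the padding below.
+    # A value of 1 adds log_epsilon to the emit->phi epsilon transition inside the scan, effectively
+    # forbidding it, so a blank must be emitted between two identical consecutive labels
+    # (the standard CTC constraint).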
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32) + repeat = jnp.pad(repeat, ((0, 0), (0, 1))) + + logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1] + logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] + + one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K] + logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot) + logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] + + logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N] + logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0) + logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N] + + def loop_body(prev, x): + prev_phi, prev_emit = prev + # emit-to-phi epsilon transition, except if the next label is repetition + prev_phi_orig = prev_phi + prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat)) + + logprob_emit, logprob_phi, pad = x + + # phi-to-emit transition + next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit) + # self-loop transition + next_phi = prev_phi + logprob_phi + # emit-to-phi blank transition only when the next label is repetition + next_phi = next_phi.at[:, 1:].set( + jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)) + ) + + pad = pad.reshape((batchsize, 1)) + next_emit = pad * prev_emit + (1.0 - pad) * next_emit + next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi + + return (next_phi, next_emit), (next_phi, next_emit) + + xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0))) + _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs) + + # last row needs to be updated with the last epsilon transition + logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1])) + logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last) + + # extract per_seq_loss + one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1] + per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot) + + if loss_reduction == "mean": + target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1) + loss = (per_seq_loss / target_lengths).mean() + elif loss_reduction == "sum": + loss = per_seq_loss.sum() + else: + loss = per_seq_loss + + if not output_emission_dict: + return loss + + return loss, { + "logalpha_phi": logalpha_phi, + "logalpha_emit": logalpha_emit, + "logprobs_phi": logprobs_phi, + "logprobs_emit": logprobs_emit, + } + + +def make_dataset(data_args, seed=42): + # Pre-processing dataset + import re + + def map_nst(entry): + text = entry["text"].lower() + text = text.replace("(...vær stille under dette opptaket...)", "") + text = re.sub('[áàâ]', 'a', text) + text = re.sub('[ä]', 'æ', text) + text = re.sub('[éèëê]', 'e', text) + text = re.sub('[íìïî]', 'i', text) + text = re.sub('[óòöô]', 'o', text) + text = re.sub('[ö]', 'ø', text) + text = re.sub('[ç]', 'c', text) + text = re.sub('[úùüû]', 'u', text) + # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text) + text = re.sub('\s+', ' ', text) + return {"text": text} + + def filter_nst(entry): + if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)): + return False # Too short + if re.match(entry["type"], "pIW|CA"): + return False # Spelling out words + return True + + def filter_npsc(entry): + # False if there are digits in the 
text + if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)): + return False # Too short + if re.search("\d", entry["text"]): + return False + return True + + def map_npsc(entry): + batch = {"text": entry["text"].lower()} + batch["text"] = re.sub('[áàâ]', 'a', batch["text"]) + batch["text"] = re.sub('[ä]', 'æ', batch["text"]) + batch["text"] = re.sub('[éèëê]', 'e', batch["text"]) + batch["text"] = re.sub('[íìïî]', 'i', batch["text"]) + batch["text"] = re.sub('[óòöô]', 'o', batch["text"]) + batch["text"] = re.sub('[ö]', 'ø', batch["text"]) + batch["text"] = re.sub('[ç]', 'c', batch["text"]) + batch["text"] = re.sub('[úùüû]', 'u', batch["text"]) + batch["text"] = re.sub('\s', ' ', batch["text"]) + batch["text"] = re.sub('', 'eee', batch["text"]) + batch["text"] = re.sub('', 'qqq', batch["text"]) + batch["text"] = re.sub('', 'mmm', batch["text"]) + batch["text"] = re.sub('', 'xxx', batch["text"]) + # batch["text"] = re.sub('', '?', batch["text"]) + if "<" in batch["text"]: + raise ValueError(batch["text"]) + return batch + + nst = datasets.load_dataset("NbAiLab/NST", "no-close") + npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3") + # TODO NST_hesitate + + split = len(npsc[data_args.train_split_name]) / (len(npsc[data_args.train_split_name]) + len(npsc[data_args.eval_split_name])) # Use same train/val ratio as NPSC + nst_train = nst[data_args.train_split_name].train_test_split(train_size=split, seed=seed) + nst[data_args.train_split_name] = nst_train["train"] + nst[data_args.eval_split_name] = nst_train["test"] + + nst = nst.filter(filter_nst).map( + map_nst, + num_proc=data_args.preprocessing_num_workers, + desc="filtering NST", + ).shuffle(seed=seed) + npsc = npsc.filter(filter_npsc).map( + map_npsc, + num_proc=data_args.preprocessing_num_workers, + desc="filtering NPSC", + ).shuffle(seed=seed) + + npsc_base = npsc.remove_columns([col for col in npsc[data_args.train_split_name].column_names if col not in ["text", "audio"]]) + nst_base = nst.remove_columns([col for col in nst[data_args.train_split_name].column_names if col not in ["text", "audio"]]) + + combined = {} + for split in data_args.train_split_name, data_args.eval_split_name, data_args.test_split_name: + probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples + probs = (probs / probs.sum()).tolist() + comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed) + combined[split] = comb + + return datasets.DatasetDict(**combined) + +def main(): + # 1. Parse input arguments + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # 2. Setup logging + # Make one log on every process with the configuration for debugging. 
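+    # Minimal, self-contained sketch of the size-proportional interleaving used in
+    # `make_dataset` above; the tiny in-memory datasets below are assumptions for
+    # illustration only and stand in for the real NST/NPSC splits. The helper is
+    # never called by this script.
+    def _interleave_by_size_example():
+        import numpy as np
+        import datasets
+        ds_a = datasets.Dataset.from_dict({"text": ["a"] * 8})  # stand-in for an NST split
+        ds_b = datasets.Dataset.from_dict({"text": ["b"] * 2})  # stand-in for an NPSC split
+        # Weight each source by its number of examples, as in make_dataset
+        probs = np.array([len(ds_a), len(ds_b)], dtype=np.float64)
+        probs = (probs / probs.sum()).tolist()  # -> [0.8, 0.2]
+        return datasets.interleave_datasets([ds_a, ds_b], probabilities=probs, seed=42)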
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + # Set the verbosity to info of the Transformers logger. + # We only want one process per machine to log things on the screen. + logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) + if jax.process_index() == 0: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # Set up wandb run + if jax.process_index() == 0: + wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type) + + logger.info("Training/evaluation parameters %s", training_args) + + # Set the default TPU matmul precision and display the number of devices + jax.config.update("jax_default_matmul_precision", training_args.matmul_precision) + logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}") + + # 4. Load dataset + + set_seed(training_args.seed) + raw_datasets = make_dataset(data_args, seed=training_args.seed) + + # raw_datasets = DatasetDict() + + # if training_args.do_train: + # raw_datasets[data_args.train_split_name] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=data_args.train_split_name, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + # if training_args.do_eval: + # raw_datasets[data_args.eval_split_name] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=data_args.eval_split_name, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + # if training_args.do_predict: + # test_split = data_args.test_split_name.split("+") + # for split in test_split: + # raw_datasets[split] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=split, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + if not training_args.do_train and not training_args.do_eval and not training_args.do_predict: + raise ValueError( + "Cannot not train, not do evaluation and not do prediction. At least one of " + "training, evaluation or prediction has to be done." + ) + + # if not training, there is no need to run multiple epochs + if not training_args.do_train: + training_args.num_train_epochs = 1 + + if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--audio_column_name` to the correct audio column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--text_column_name` to the correct text column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + # 5. 
Load pretrained model, tokenizer, and feature extractor + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + config = Wav2Vec2Config.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + # update config according to training args, model args, and tokenizer attributes + config.update( + { + "feat_proj_dropout": model_args.feat_proj_dropout, + "attention_dropout": model_args.attention_dropout, + "hidden_dropout": model_args.hidden_dropout, + "final_dropout": model_args.final_dropout, + "mask_time_prob": model_args.mask_time_prob, + "mask_time_length": model_args.mask_time_length, + "mask_feature_prob": model_args.mask_feature_prob, + "mask_feature_length": model_args.mask_feature_length, + "gradient_checkpointing": training_args.gradient_checkpointing, + "layerdrop": model_args.layerdrop, + "ctc_loss_reduction": model_args.ctc_loss_reduction, + "ctc_zero_infinity": model_args.ctc_zero_infinity, + "pad_token_id": tokenizer.pad_token_id, + "vocab_size": tokenizer.vocab_size, # len(tokenizer), + "activation_dropout": model_args.activation_dropout, + } + ) + + if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr": + raise ValueError( + "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to " + "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus," + "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is mostly likely " + "built on an lowercased corpus. In this case, set `tokenizer.do_lower_case` to ``False`." + ) + + if training_args.precision == "full_mixed": + dtype = jnp.bfloat16 + training_args.mixed_precision = True + elif training_args.precision == "half_mixed": + dtype = jnp.bfloat16 + training_args.mixed_precision = False + else: + dtype = jnp.float32 + training_args.mixed_precision = False + + try: + model = FlaxWav2Vec2ForCTC.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + except: + model = FlaxWav2Vec2ForCTC.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + from_pt=True, + ) + + # 6. Resample speech dataset ALWAYS + raw_datasets = raw_datasets.cast_column( + data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) + ) + + # 7. Preprocessing the datasets. 
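+    # Minimal sketch of the resampling cast in step 6 above: after `cast_column`,
+    # accessing the audio column decodes and resamples lazily to the requested rate.
+    # The file path below is a placeholder assumption, so decoding is left to the
+    # caller; the helper is never called by this script.
+    def _resample_cast_example(sampling_rate: int = 16_000):
+        import datasets
+        ds = datasets.Dataset.from_dict({"audio": ["/path/to/clip.wav"]})
+        ds = ds.cast_column("audio", datasets.features.Audio(sampling_rate=sampling_rate))
+        # ds[0]["audio"]["array"] would now be a float array resampled to `sampling_rate`
+        return ds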
+ # We need to read the audio files as arrays and tokenize the targets. + max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate) + min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate) + max_target_length = data_args.max_label_length + min_target_length = data_args.min_label_length + pad_input_to_multiple_of = data_args.pad_input_to_multiple_of + audio_column_name = data_args.audio_column_name + num_workers = data_args.preprocessing_num_workers + text_column_name = data_args.text_column_name + model_input_name = feature_extractor.model_input_names[0] + do_lower_case = data_args.do_lower_case + dataset_name = data_args.dataset_name + chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ") + chars_to_ignore_regex = f'[{"".join(chars_to_ignore)}]' + # gigaspeech_punctuation = {" ": ",", " ": ".", " ": "?", " ": "!"} + # gigaspeech_disfluencies = ["", ""] + # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "", "", "", "[laughter-", + # "[vocalized-noise]", "_1"] + # swb_punctuations = ["{", "}", "[", "]-", "]"] + # earnings_disfluencies = ["", "", "", "inaudible", "", ""] + ignore_segments = ["ignore_time_segment_in_scoring", "", "", "[noise]", "[laughter]", "[silence]", + "[vocalized-noise]", "", "", "", "", "", "", ""] + + if training_args.do_train and data_args.max_train_samples is not None: + raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].select(range(data_args.max_train_samples)) + + if training_args.do_eval and data_args.max_eval_samples is not None: + raw_datasets[data_args.eval_split_name] = raw_datasets[data_args.eval_split_name].select(range(data_args.max_eval_samples)) + + if training_args.do_predict and data_args.max_test_samples is not None: + raw_datasets[data_args.test_split_name] = raw_datasets[data_args.test_split_name].select(range(data_args.max_eval_samples)) + + if training_args.do_train and data_args.remove_punctuation: + + def remove_punctuation(batch): + batch[text_column_name] = ( + re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "") + ) + + raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].map( + remove_punctuation, + num_proc=data_args.preprocessing_num_workers, + desc="removing punctuation from train split", + ) + + # filter data where the targets are ignored in scoring + def is_target_labels(input_str): + return input_str.lower() not in ignore_segments + + raw_datasets = raw_datasets.filter( + is_target_labels, + num_proc=num_workers, + input_columns=[text_column_name], + desc="filtering data where the targets are ignored in scoring", + ) + + def prepare_dataset(batch): + # process audio + try: + sample = batch[audio_column_name] + except ValueError: + sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate} + inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"]) + # process audio length + batch[model_input_name] = inputs.input_values[0] + batch["input_length"] = len(batch["input_values"]) + + # process targets + input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name] + + # if dataset_name == "google/xtreme_s": + # # Finally, we tokenize the processed text + # batch["labels"] = tokenizer(input_str).input_ids + # batch["labels_length"] = len(batch["labels"]) + # return batch + + # # Common Voice 9 + # if input_str.startswith('"') and input_str.endswith('"'): + # # we can remove 
trailing quotation marks as they do not affect the transcription + # input_str = input_str[1:-1] + # # normalize quotation marks + # input_str = re.sub(r'["“”]', '"', input_str) + # # normalize apostrophes + # input_str = re.sub(r"[’']", "'", input_str) + # # normalize hyphens + # input_str = re.sub(r"[—–]", "-", input_str) + # # replace double quotation marks with single + # input_str = input_str.replace('""', '"') + # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str): + # # for CV9, we'll normalize the text to always finish with punctuation + # if input_str[-1] not in [".", "?", "!"]: + # input_str = input_str + "." + + # # TEDLIUM-3 + # # delete the token from the text and replace spaced apostrophes with un-spaced + # input_str = input_str.replace("", "").replace(" '", "'") + + # # GigaSpeech + # for disfluency in gigaspeech_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # convert spelled out punctuation to symbolic form + # for punctuation, replacement in gigaspeech_punctuation.items(): + # input_str = input_str.replace(punctuation, replacement) + # if dataset_name == "speechcolab/gigaspeech" and len(input_str): + # # for GS, we'll normalize the text to always finish with punctuation + # if input_str[-1] not in [".", "?", "!"]: + # input_str = input_str + "." + + # # SWB + # for disfluency in swb_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # remove parenthesised text (test data only) + # input_str = re.sub("[\(].*?[\)]", "", input_str) + # for punctuation in swb_punctuations: + # input_str = input_str.replace(punctuation, "") + # # replace anomalous words with their correct transcriptions + # split_str = input_str.split("/") + # if len(split_str) > 1: + # input_str = " ".join( + # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]]) + + # # Earnings 22 + # for disfluency in earnings_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # replace mal-formatted ellipsis + # input_str = input_str.replace("…", ".") + + # JIWER compliance + # remove multiple spaces + input_str = re.sub(r"\s\s+", " ", input_str) + # strip trailing spaces + input_str = input_str.strip() + + # Finally, we tokenize the processed text + batch["labels"] = tokenizer(input_str).input_ids + batch["labels_length"] = len(batch["labels"]) + return batch + + vectorized_datasets = raw_datasets.map( + prepare_dataset, + remove_columns=next(iter(raw_datasets.values())).column_names, + num_proc=num_workers, + desc="preprocess dataset", + ) + + # filter data with inputs shorter than min_input_length or longer than max_input_length + def is_audio_in_length_range(length): + return length > min_input_length and length < max_input_length + + vectorized_datasets = vectorized_datasets.filter( + is_audio_in_length_range, + num_proc=num_workers, + input_columns=["input_length"], + ) + + # filter data with targets shorter than min_target_length or longer than max_target_length + def is_labels_in_length_range(length): + return length > min_target_length # and length < max_target_length + + vectorized_datasets = vectorized_datasets.filter( + is_labels_in_length_range, + num_proc=num_workers, + input_columns=["labels_length"], + ) + + # for large datasets it is advised to run the preprocessing on a + # single machine first with `args.preprocessing_only` since there will mostly likely + # be a timeout when running the script in distributed mode. 
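+    # Minimal sketch of the length filtering applied just above, on a toy dataset;
+    # the sample counts assume 16 kHz audio with 0.5 s / 30 s duration limits and
+    # the helper is never called by this script.
+    def _length_filter_example():
+        import datasets
+        toy = datasets.Dataset.from_dict({"input_length": [4_000, 80_000, 1_000_000]})
+        min_len, max_len = 8_000, 480_000  # 0.5 s and 30 s of audio at 16 kHz
+        # only the 80 000-sample (5 s) clip survives the filter
+        return toy.filter(lambda length: min_len < length < max_len, input_columns=["input_length"])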
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the + # cached dataset + if data_args.preprocessing_only: + cache = {k: v.cache_files for k, v in vectorized_datasets.items()} + logger.info(f"Data preprocessing finished. Files cached at {cache}.") + return + + # 8. Load Metrics + wer_metric = load_metric("wer") + cer_metric = load_metric("cer") + + def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]): + padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids)) + + pred_str = tokenizer.batch_decode(pred_ids) + # we do not want to group tokens when computing the metrics + label_str = tokenizer.batch_decode(padded_ids, group_tokens=False) + + wer = wer_metric.compute(predictions=pred_str, references=label_str) + cer = cer_metric.compute(predictions=pred_str, references=label_str) + + return {"wer": wer, "cer": cer}, pred_str, label_str + + # 9. save feature extractor, tokenizer and config + feature_extractor.save_pretrained(training_args.output_dir) + tokenizer.save_pretrained(training_args.output_dir) + config.save_pretrained(training_args.output_dir) + + processor = AutoProcessor.from_pretrained(training_args.output_dir) + + data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding( + processor=processor, + input_padding="longest", + pad_input_to_multiple_of=pad_input_to_multiple_of, + max_label_length=data_args.max_label_length, + ) + + # Enable tensorboard only on the master node + has_tensorboard = is_tensorboard_available() + if has_tensorboard and jax.process_index() == 0: + try: + from flax.metrics.tensorboard import SummaryWriter + + summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) + except ImportError as ie: + has_tensorboard = False + logger.warning( + f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" + ) + else: + logger.warning( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run `pip install tensorboard` to enable." + ) + + # 10. Handle the repository creation + if training_args.push_to_hub: + with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f: + git_lfs_extensions = f.read() + if "*.wandb" not in git_lfs_extensions: + f.write("*.wandb filter=lfs diff=lfs merge=lfs -text") + if training_args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) + else: + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) + + # 11. 
Initialize our training + rng = jax.random.PRNGKey(training_args.seed) + rng, dropout_rng = jax.random.split(rng) + + # Store some constants + max_steps = int(training_args.max_steps) + gradient_accumulation_steps = int(training_args.gradient_accumulation_steps) + train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() + batch_size_per_update = train_batch_size * gradient_accumulation_steps + per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) + eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() + to_dtype = to_bf16 if training_args.mixed_precision else to_fp32 + + if training_args.do_train: + num_train_samples = len(vectorized_datasets[data_args.train_split_name]) + steps_per_epoch = num_train_samples // batch_size_per_update + if max_steps > 0: + num_epochs = -(training_args.max_steps // -steps_per_epoch) + total_train_steps = max_steps + else: + num_epochs = int(training_args.num_train_epochs) + total_train_steps = steps_per_epoch * num_epochs + + # Create learning rate schedule + # Create learning rate schedule + linear_decay_lr_schedule_fn = create_learning_rate_fn( + total_train_steps, + training_args.warmup_steps, + training_args.learning_rate, + ) + + # We use Optax's "masking" functionality to not apply weight decay + # to bias and LayerNorm scale parameters. decay_mask_fn returns a + # mask boolean with the same structure as the parameters. + # The mask is True for parameters that should be decayed. + # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart. + # For FlaxT5, one should correct the layer norm parameter naming + # accordingly - see `run_t5_mlm_flax.py` e.g. + def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + layer_norm_params = [ + (name, "scale") + for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"] + ] + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + + if training_args.adafactor: + # Create Adafactor optimizer + optim = optax.adafactor( + learning_rate=linear_decay_lr_schedule_fn, + dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32, + weight_decay_rate=training_args.weight_decay, + weight_decay_mask=decay_mask_fn, + ) + else: + # Create AdamW optimizer + optim = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn, + ) + + # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. 
gradient accumulation steps > 1) + if training_args.multisteps and gradient_accumulation_steps > 1: + optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False) + else: + num_epochs = 0 + total_train_steps = 0 + num_train_samples = 0 + optim = None + + # Setup train state + state = MixedPrecisionTrainState.create( + apply_fn=model.__call__, + get_attention_mask_fn=model._get_feature_vector_attention_mask, + params=model.params, + tx=optim, + to_dtype=to_dtype, + dropout_rng=dropout_rng, + max_grad_norm=training_args.max_grad_norm, + ) + + # Replicate the train state on each device + state = state.replicate() + blank_id = model.config.pad_token_id + + # Define gradient update step fn + def train_step(state, batch): + # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch + dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + + def compute_loss(params, minibatch): + labels = minibatch.pop("labels") + logits = state.apply_fn( + **minibatch, + params=params, + dropout_rng=dropout_rng, + freeze_feature_encoder=model_args.freeze_feature_encoder, + train=True, + )[0] + logits_mask = state.get_attention_mask_fn(logits.shape[1], batch["attention_mask"]) + loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean") + + return loss + + grad_fn = jax.value_and_grad(compute_loss) + + if gradient_accumulation_steps == 1 or training_args.multisteps: + loss, grad = grad_fn(to_dtype(state.params), batch) + + # Custom gradient accumulation + else: + # add a first dimension over gradient_accumulation_steps for minibatch slices + batch = jax.tree_util.tree_map( + lambda x: x.reshape( + gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::] + ), + batch, + ) + + def accum_minibatch_step(accum_grad, minibatch): + # compute loss, num labels and grad over minibatch and accumulate + loss, grad = grad_fn(to_dtype(state.params), minibatch) + return jax.tree_util.tree_map(jnp.add, accum_grad, grad), loss + + # create an initial state for accumulating losses, num labels and gradients + init_grad = jax.tree_util.tree_map(jnp.zeros_like, to_dtype(state.params)) + # loop accum minibatch step over the number of gradient accumulation steps + grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch) + + # update state + new_state = state.apply_gradients( + grads=grad, + dropout_rng=new_dropout_rng, + to_dtype=to_dtype, + ) + + # compute gradient norms over all layers and globally for detailed monitoring + layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad) + logs = { + "layer_grad_norm": layer_grad_norm, + "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)), + } + + # compute parameter norms over all layers and globally for detailed monitoring + layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params) + logs["layer_param_norm"] = layer_param_norm + logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm)) + + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + metrics.update(logs) + + metrics = jax.lax.pmean(metrics, axis_name="batch") + # metrics = to_fp32(metrics) + + return new_state, metrics + + # Define eval fn + def eval_step(params, batch): + labels = batch.pop("labels") + logits = model(**batch, params=params, train=False)[0] + + logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], batch["attention_mask"]) + loss = 
ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean") + + pred_ids = jnp.argmax(logits, axis=-1) + + # summarize metrics + metrics = {"loss": loss} + metrics = jax.lax.pmean(metrics, axis_name="batch") + # metrics = to_fp32(metrics) + return metrics, pred_ids + + # Create parallel version of the train and eval step + if training_args.do_train: + p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) + + if training_args.do_eval: + p_eval_step = jax.pmap(eval_step, "batch") + + def run_evaluation(step): + if training_args.do_eval: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_preds = [] + eval_labels = [] + + # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length + eval_samples_idx = get_grouped_indices(vectorized_datasets[data_args.eval_split_name], eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) + + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): + samples = [vectorized_datasets[data_args.eval_split_name][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + labels = batch["labels"] + + metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size) + eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))) + eval_metrics.append(metrics) + + eval_labels.extend(labels) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + eval_metrics = to_fp32(eval_metrics) + + # always run compute metrics + error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels) + eval_metrics.update(error_rate_metric) + error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()]) + + # Print metrics and update progress bar + desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})" + epochs.write(desc) + epochs.desc = desc + + # Save metrics + write_wandb_log(eval_metrics, step, prefix="eval") + write_wandb_pred(pred_str, label_str, step) + # if has_tensorboard and jax.process_index() == 0: + # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str) + + def save_checkpoint(step): + # save and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False) + + skip_epochs = data_args.skip_steps // (num_train_samples // batch_size_per_update) + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_train_samples}") + logger.info(f" Num Epochs = {num_epochs}") + logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") + logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}") + logger.info(f" Total train batch size (w. 
parallel & distributed) = {batch_size_per_update}") + logger.info(f" Total optimization steps = {total_train_steps}") + logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}") + logger.info(f" Use scan: {config.use_scan}") + logger.info(f" Fuse matmuls: {config.fuse_matmuls}") + logger.info(f" Skipping: {data_args.skip_steps} steps ({skip_epochs} epochs)") + + train_time = cur_step = 0 + epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) + for epoch in epochs: + if training_args.do_train: + # ======================== Training ================================ + train_start = time.time() + + if epoch < skip_epochs: + logger.info(f"Skipping epoch {epoch + 1}") + continue + + # Create sampling rng + rng, input_rng = jax.random.split(rng) + + # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length + train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng) + train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update) + + if data_args.skip_steps > cur_step: + logger.info(f"Skipping {data_args.skip_steps - (epoch * (num_train_samples // batch_size_per_update))} steps...") + # Gather the indices for creating the batch and do a training step + for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1): + cur_step = epoch * (num_train_samples // batch_size_per_update) + step + if cur_step <= data_args.skip_steps: + continue + + samples = [vectorized_datasets[data_args.train_split_name][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + batch = shard(batch.data) + try: + state, train_metric = p_train_step(state, batch) + except TypeError as e: + logger.warning("Encountered following error: \n", e) + + + if cur_step % training_args.logging_steps == 0: + # Save metrics + train_metric = unreplicate(train_metric) + train_time += time.time() - train_start + # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step + write_wandb_log(to_fp32(train_metric), cur_step, prefix=data_args.train_split_name) + # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis) + # if has_tensorboard and jax.process_index() == 0: + # write_train_metric(summary_writer, train_metrics, train_time, cur_step) + + epochs.write( + f"Step... 
({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})" + ) + + if cur_step % total_train_steps == 0: + break + + if training_args.eval_steps and cur_step % training_args.eval_steps == 0: + run_evaluation(cur_step) + + if cur_step % training_args.save_steps == 0: + save_checkpoint(cur_step) + + if training_args.eval_steps == 0 and (epoch + 1) != num_epochs: + # run evaluation at the end of the epoch if eval steps are not specified + run_evaluation(cur_step) + save_checkpoint(cur_step) + + if training_args.do_train: + save_checkpoint(cur_step) + + cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training + + if training_args.do_eval: + run_evaluation(cur_step) + + # TODO: collapse 'do_predict' into the run_evaluation function + if training_args.do_predict: + for split in [data_args.test_split_name]: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_preds = [] + eval_labels = [] + + # Generate eval set by sequentially sampling indices from the test dataset and grouping by length + eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) + + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)): + samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + labels = batch["labels"] + + metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size) + eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))) + eval_metrics.append(metrics) + + eval_labels.extend(labels) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + eval_metrics = to_fp32(eval_metrics) + + # always run compute metrics + error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels) + eval_metrics.update(error_rate_metric) + error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()]) + + # Print metrics and update progress bar + desc = f"Step... 
({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})" + epochs.write(desc) + epochs.desc = desc + + # Save metrics + write_wandb_log(eval_metrics, cur_step, prefix=split) + write_wandb_pred(pred_str, label_str, cur_step, prefix=split) + # if has_tensorboard and jax.process_index() == 0: + # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str) + + +if __name__ == "__main__": + main() diff --git a/wandb/run-20220810_111559-290849gb/files/config.yaml b/wandb/run-20220810_111559-290849gb/files/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..36b6c4d6413f90e627e8ab6ec38ecd5090e713a3 --- /dev/null +++ b/wandb/run-20220810_111559-290849gb/files/config.yaml @@ -0,0 +1,33 @@ +wandb_version: 1 + +_wandb: + desc: null + value: + cli_version: 0.12.9 + code_path: code/run_flax_speech_recognition_ctc.py + framework: huggingface + huggingface_version: 4.21.0 + is_jupyter_run: false + is_kaggle_kernel: false + python_version: 3.8.10 + start_time: 1660130159 + t: + 1: + - 1 + - 2 + - 3 + - 11 + - 12 + 2: + - 1 + - 2 + - 3 + - 11 + - 12 + 3: + - 13 + 4: 3.8.10 + 5: 0.12.9 + 6: 4.21.0 + 8: + - 5 diff --git a/wandb/run-20220810_111559-290849gb/files/diff.patch b/wandb/run-20220810_111559-290849gb/files/diff.patch new file mode 100644 index 0000000000000000000000000000000000000000..7365edd134e852ac32abfecab9fe7060bf648ebf --- /dev/null +++ b/wandb/run-20220810_111559-290849gb/files/diff.patch @@ -0,0 +1,52 @@ +diff --git a/special_tokens_map.json b/special_tokens_map.json +index 218961f..c11fc15 100644 +--- a/special_tokens_map.json ++++ b/special_tokens_map.json +@@ -399,6 +399,20 @@ + "rstrip": false, + "single_word": false + }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, + { + "content": "", + "lstrip": false, +diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log +index 23926ef..ad68f93 120000 +--- a/wandb/debug-internal.log ++++ b/wandb/debug-internal.log +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4/logs/debug-internal.log +\ No newline at end of file ++run-20220810_111559-290849gb/logs/debug-internal.log +\ No newline at end of file +diff --git a/wandb/debug.log b/wandb/debug.log +index 279853d..8db277f 120000 +--- a/wandb/debug.log ++++ b/wandb/debug.log +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4/logs/debug.log +\ No newline at end of file ++run-20220810_111559-290849gb/logs/debug.log +\ No newline at end of file +diff --git a/wandb/latest-run b/wandb/latest-run +index f069a7a..052e8bb 120000 +--- a/wandb/latest-run ++++ b/wandb/latest-run +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4 +\ No newline at end of file ++run-20220810_111559-290849gb +\ No newline at end of file diff --git a/wandb/run-20220810_111559-290849gb/files/output.log b/wandb/run-20220810_111559-290849gb/files/output.log new file mode 100644 index 0000000000000000000000000000000000000000..b4d3b249ce96994945e6c76c74d298745f9142c6 --- /dev/null +++ b/wandb/run-20220810_111559-290849gb/files/output.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a5d175f5b339eb2b7e07c06a41cf262a197646be241ea83c5bc595ad9c114374 +size 209075 diff --git a/wandb/run-20220810_111559-290849gb/files/requirements.txt b/wandb/run-20220810_111559-290849gb/files/requirements.txt new file mode 100644 index 
0000000000000000000000000000000000000000..e0273eb6554b8538eecc3cb9f4a47c988bd3d0dd --- /dev/null +++ b/wandb/run-20220810_111559-290849gb/files/requirements.txt @@ -0,0 +1,158 @@ +absl-py==1.0.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +appdirs==1.4.4 +astunparse==1.6.3 +async-timeout==4.0.2 +attrs==21.4.0 +audioread==2.1.9 +backcall==0.2.0 +cachetools==4.2.4 +certifi==2021.10.8 +cffi==1.15.1 +charset-normalizer==2.0.10 +chex==0.1.3 +click==8.0.3 +cloud-tpu-client==0.10 +cloud-tpu-profiler==2.4.0 +clu==0.0.6 +colorama==0.4.5 +commonmark==0.9.1 +configparser==5.2.0 +contextlib2==21.6.0 +cycler==0.11.0 +datasets==2.4.0 +decorator==5.1.0 +dill==0.3.4 +dm-tree==0.1.6 +docker-pycreds==0.4.0 +etils==0.6.0 +exceptiongroup==1.0.0rc8 +filelock==3.4.2 +flatbuffers==2.0 +flax==0.5.3 +fonttools==4.28.5 +frozenlist==1.2.0 +fsspec==2021.11.1 +future==0.18.2 +gast==0.4.0 +gitdb==4.0.9 +gitpython==3.1.26 +google-api-core==1.31.5 +google-api-python-client==1.8.0 +google-auth-httplib2==0.1.0 +google-auth-oauthlib==0.4.6 +google-auth==2.3.3 +google-pasta==0.2.0 +googleapis-common-protos==1.54.0 +grpcio==1.43.0 +h5py==3.6.0 +httplib2==0.20.2 +huggingface-hub==0.2.1 +hypothesis==6.53.0 +idna==3.3 +importlib-metadata==4.10.0 +importlib-resources==5.4.0 +ipython==7.31.0 +jax==0.3.15 +jaxlib==0.3.15 +jedi==0.18.1 +jiwer==2.3.0 +joblib==1.1.0 +keras-preprocessing==1.1.2 +keras==2.7.0 +kiwisolver==1.3.2 +libclang==12.0.0 +librosa==0.9.2 +libtpu-nightly==0.1.dev20220722 +llvmlite==0.39.0 +markdown==3.3.6 +matplotlib-inline==0.1.3 +matplotlib==3.5.1 +ml-collections==0.1.0 +msgpack==1.0.3 +multidict==5.2.0 +multiprocess==0.70.12.2 +numba==0.56.0 +numpy==1.22.0 +oauth2client==4.1.3 +oauthlib==3.1.1 +opt-einsum==3.3.0 +optax==0.1.3 +packaging==21.3 +pandas==1.3.5 +parso==0.8.3 +pathtools==0.1.2 +pexpect==4.8.0 +pickleshare==0.7.5 +pillow==9.0.0 +pip==22.2.2 +pkg-resources==0.0.0 +pooch==1.6.0 +promise==2.3 +prompt-toolkit==3.0.24 +protobuf==3.19.1 +psutil==5.9.0 +ptyprocess==0.7.0 +pyarrow==6.0.1 +pyasn1-modules==0.2.8 +pyasn1==0.4.8 +pycparser==2.21 +pyctcdecode==0.4.0 +pygments==2.11.1 +pygtrie==2.5.0 +pyparsing==3.0.6 +python-dateutil==2.8.2 +python-levenshtein==0.12.2 +pytz==2021.3 +pyyaml==6.0 +regex==2021.11.10 +requests-oauthlib==1.3.0 +requests==2.27.0 +resampy==0.3.1 +responses==0.18.0 +rich==11.2.0 +rsa==4.8 +sacremoses==0.0.46 +scikit-learn==1.1.1 +scipy==1.7.3 +sentry-sdk==1.5.2 +setuptools==44.0.0 +shortuuid==1.0.8 +six==1.16.0 +smmap==5.0.0 +sortedcontainers==2.4.0 +soundfile==0.10.3.post1 +sox==1.4.1 +subprocess32==3.5.4 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.0 +tensorboard==2.7.0 +tensorflow-cpu==2.7.0 +tensorflow-datasets==4.4.0 +tensorflow-estimator==2.7.0 +tensorflow-io-gcs-filesystem==0.23.1 +tensorflow-metadata==1.5.0 +tensorflow==2.7.0 +tensorstore==0.1.21 +termcolor==1.1.0 +threadpoolctl==3.1.0 +tokenizers==0.11.2 +toolz==0.11.2 +torch==1.12.0 +torchaudio==0.12.0+cpu +tqdm==4.62.3 +traitlets==5.1.1 +transformers==4.21.0 +typing-extensions==4.3.0 +uritemplate==3.0.1 +urllib3==1.26.7 +wandb==0.12.9 +wcwidth==0.2.5 +werkzeug==2.0.2 +wheel==0.37.1 +wrapt==1.13.3 +xxhash==2.0.2 +yarl==1.7.2 +yaspin==2.1.0 +zipp==3.7.0 \ No newline at end of file diff --git a/wandb/run-20220810_111559-290849gb/files/wandb-metadata.json b/wandb/run-20220810_111559-290849gb/files/wandb-metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..98bd1183d2c0b0bf0f68c47d646323d68e45f092 --- /dev/null +++ b/wandb/run-20220810_111559-290849gb/files/wandb-metadata.json @@ -0,0 +1,70 @@ +{ 
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29", + "python": "3.8.10", + "heartbeatAt": "2022-08-10T11:16:02.847385", + "startedAt": "2022-08-10T11:15:59.241818", + "docker": null, + "cpu_count": 96, + "cuda": null, + "args": [ + "--model_name_or_path=./", + "--hub_model_id=NbAiLab/wav2vec2-1b-npsc-nst-tpu", + "--tokenizer_name=./", + "--output_dir=./", + "--overwrite_output_dir", + "--num_train_epochs=40", + "--per_device_train_batch_size=2", + "--per_device_eval_batch_size=2", + "--gradient_accumulation_steps=1", + "--precision=full_mixed", + "--matmul_precision=bfloat16", + "--multisteps", + "--learning_rate=6.394633237505332e-05", + "--skip_steps=275000", + "--warmup_steps=2000", + "--length_column_name=input_length", + "--evaluation_strategy=steps", + "--text_column_name=text", + "--save_steps=5000", + "--eval_steps=5000", + "--logging_steps=100", + "--layerdrop=0.041", + "--attention_dropout=0.094", + "--activation_dropout=0.055", + "--hidden_dropout=0.047", + "--save_total_limit=5", + "--freeze_feature_encoder", + "--feat_proj_dropout=0.04", + "--mask_time_prob=0.082", + "--mask_time_length=10", + "--mask_feature_prob=0.25", + "--mask_feature_length=64", + "--gradient_checkpointing", + "--min_duration_in_seconds=0.5", + "--max_duration_in_seconds=30.0", + "--use_auth_token", + "--seed=42", + "--group_by_length", + "--do_train", + "--do_eval", + "--push_to_hub", + "--preprocessing_num_workers=32", + "--ctc_zero_infinity", + "--do_lower_case", + "--wandb_project=wav2vec2", + "--wandb_name=wav2vec2-1b-npsc-nst-tpu (cont.)", + "--remove_punctuation" + ], + "state": "running", + "program": "run_flax_speech_recognition_ctc.py", + "codePath": "run_flax_speech_recognition_ctc.py", + "git": { + "remote": "https://huggingface.co/NbAiLab/wav2vec2-1b-npsc-nst-tpu", + "commit": "f624ac4bfedfbf56891676d7c5f2e37b4c8e0745" + }, + "email": "versae@gmail.com", + "root": "/data/wav2vec2-1b-npsc-nst-tpu", + "host": "t1v-n-eedfb410-w-0", + "username": "javierr", + "executable": "/data/flax/bin/python" +} diff --git a/wandb/run-20220810_111559-290849gb/files/wandb-summary.json b/wandb/run-20220810_111559-290849gb/files/wandb-summary.json new file mode 100644 index 0000000000000000000000000000000000000000..f88df35c7282944009866b2a6e96cd4663f59cc5 --- /dev/null +++ b/wandb/run-20220810_111559-290849gb/files/wandb-summary.json @@ -0,0 +1 @@ +{"train/grad_norm": 6.5625, "layer_grad_norm/": {"lm_head": {"bias": 0.031982421875, "kernel": 4.625}, "wav2vec2": {"encoder": {"layer_norm": {"bias": 0.0556640625, "scale": 0.06103515625}, "layers": {"0": {"attention": {"k_proj": {"bias": 5.9604644775390625e-05, "kernel": 0.023681640625}, "out_proj": {"bias": 0.04150390625, "kernel": 0.2431640625}, "q_proj": {"bias": 0.002899169921875, "kernel": 0.031005859375}, "v_proj": {"bias": 0.037109375, "kernel": 0.265625}}, "feed_forward": {"intermediate_dense": {"bias": 0.04443359375, "kernel": 0.515625}, "output_dense": {"bias": 0.0159912109375, "kernel": 0.439453125}}, "final_layer_norm": {"bias": 0.146484375, "scale": 0.322265625}, "layer_norm": {"bias": 0.0703125, "scale": 0.07080078125}}, "1": {"attention": {"k_proj": {"bias": 3.4332275390625e-05, "kernel": 0.03955078125}, "out_proj": {"bias": 0.0177001953125, "kernel": 0.134765625}, "q_proj": {"bias": 0.0035247802734375, "kernel": 0.0439453125}, "v_proj": {"bias": 0.02880859375, "kernel": 0.111328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.022216796875, "kernel": 0.3359375}, "output_dense": {"bias": 0.0157470703125, "kernel": 0.259765625}}, 
"final_layer_norm": {"bias": 0.046142578125, "scale": 0.05712890625}, "layer_norm": {"bias": 0.05712890625, "scale": 0.039794921875}}, "10": {"attention": {"k_proj": {"bias": 3.600120544433594e-05, "kernel": 0.123046875}, "out_proj": {"bias": 0.01416015625, "kernel": 0.2001953125}, "q_proj": {"bias": 0.0078125, "kernel": 0.12255859375}, "v_proj": {"bias": 0.022705078125, "kernel": 0.2001953125}}, "feed_forward": {"intermediate_dense": {"bias": 0.021484375, "kernel": 0.328125}, "output_dense": {"bias": 0.013671875, "kernel": 0.2734375}}, "final_layer_norm": {"bias": 0.0390625, "scale": 0.0322265625}, "layer_norm": {"bias": 0.04931640625, "scale": 0.0341796875}}, "11": {"attention": {"k_proj": {"bias": 8.344650268554688e-05, "kernel": 0.158203125}, "out_proj": {"bias": 0.0142822265625, "kernel": 0.28125}, "q_proj": {"bias": 0.0087890625, "kernel": 0.130859375}, "v_proj": {"bias": 0.024658203125, "kernel": 0.28515625}}, "feed_forward": {"intermediate_dense": {"bias": 0.01953125, "kernel": 0.310546875}, "output_dense": {"bias": 0.013916015625, "kernel": 0.244140625}}, "final_layer_norm": {"bias": 0.03271484375, "scale": 0.0308837890625}, "layer_norm": {"bias": 0.05029296875, "scale": 0.0439453125}}, "12": {"attention": {"k_proj": {"bias": 4.982948303222656e-05, "kernel": 0.123046875}, "out_proj": {"bias": 0.0147705078125, "kernel": 0.244140625}, "q_proj": {"bias": 0.0081787109375, "kernel": 0.1162109375}, "v_proj": {"bias": 0.023681640625, "kernel": 0.2294921875}}, "feed_forward": {"intermediate_dense": {"bias": 0.021484375, "kernel": 0.32421875}, "output_dense": {"bias": 0.014404296875, "kernel": 0.255859375}}, "final_layer_norm": {"bias": 0.0390625, "scale": 0.04248046875}, "layer_norm": {"bias": 0.046630859375, "scale": 0.0546875}}, "13": {"attention": {"k_proj": {"bias": 0.00012493133544921875, "kernel": 0.15625}, "out_proj": {"bias": 0.01519775390625, "kernel": 0.330078125}, "q_proj": {"bias": 0.0111083984375, "kernel": 0.158203125}, "v_proj": {"bias": 0.026611328125, "kernel": 0.3046875}}, "feed_forward": {"intermediate_dense": {"bias": 0.022216796875, "kernel": 0.34375}, "output_dense": {"bias": 0.01513671875, "kernel": 0.3125}}, "final_layer_norm": {"bias": 0.040283203125, "scale": 0.032958984375}, "layer_norm": {"bias": 0.051513671875, "scale": 0.091796875}}, "14": {"attention": {"k_proj": {"bias": 9.441375732421875e-05, "kernel": 0.1005859375}, "out_proj": {"bias": 0.015625, "kernel": 0.2412109375}, "q_proj": {"bias": 0.006256103515625, "kernel": 0.099609375}, "v_proj": {"bias": 0.0235595703125, "kernel": 0.2275390625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0257568359375, "kernel": 0.39453125}, "output_dense": {"bias": 0.015380859375, "kernel": 0.33984375}}, "final_layer_norm": {"bias": 0.05126953125, "scale": 0.05517578125}, "layer_norm": {"bias": 0.041748046875, "scale": 0.03076171875}}, "15": {"attention": {"k_proj": {"bias": 0.0003070831298828125, "kernel": 0.1806640625}, "out_proj": {"bias": 0.015625, "kernel": 0.5078125}, "q_proj": {"bias": 0.0106201171875, "kernel": 0.173828125}, "v_proj": {"bias": 0.026611328125, "kernel": 0.361328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.024169921875, "kernel": 0.376953125}, "output_dense": {"bias": 0.01556396484375, "kernel": 0.349609375}}, "final_layer_norm": {"bias": 0.041015625, "scale": 0.033447265625}, "layer_norm": {"bias": 0.048095703125, "scale": 0.072265625}}, "16": {"attention": {"k_proj": {"bias": 6.389617919921875e-05, "kernel": 0.1025390625}, "out_proj": {"bias": 0.016357421875, "kernel": 
0.267578125}, "q_proj": {"bias": 0.0057373046875, "kernel": 0.1005859375}, "v_proj": {"bias": 0.0238037109375, "kernel": 0.220703125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0223388671875, "kernel": 0.359375}, "output_dense": {"bias": 0.0159912109375, "kernel": 0.341796875}}, "final_layer_norm": {"bias": 0.0380859375, "scale": 0.033935546875}, "layer_norm": {"bias": 0.043212890625, "scale": 0.034912109375}}, "17": {"attention": {"k_proj": {"bias": 4.57763671875e-05, "kernel": 0.0927734375}, "out_proj": {"bias": 0.0172119140625, "kernel": 0.23046875}, "q_proj": {"bias": 0.005889892578125, "kernel": 0.087890625}, "v_proj": {"bias": 0.0244140625, "kernel": 0.2177734375}}, "feed_forward": {"intermediate_dense": {"bias": 0.024169921875, "kernel": 0.390625}, "output_dense": {"bias": 0.01708984375, "kernel": 0.353515625}}, "final_layer_norm": {"bias": 0.041259765625, "scale": 0.036376953125}, "layer_norm": {"bias": 0.0439453125, "scale": 0.0341796875}}, "18": {"attention": {"k_proj": {"bias": 0.000247955322265625, "kernel": 0.126953125}, "out_proj": {"bias": 0.017578125, "kernel": 0.369140625}, "q_proj": {"bias": 0.0076904296875, "kernel": 0.1337890625}, "v_proj": {"bias": 0.027587890625, "kernel": 0.298828125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02587890625, "kernel": 0.44921875}, "output_dense": {"bias": 0.0174560546875, "kernel": 0.41015625}}, "final_layer_norm": {"bias": 0.04443359375, "scale": 0.03857421875}, "layer_norm": {"bias": 0.048583984375, "scale": 0.039794921875}}, "19": {"attention": {"k_proj": {"bias": 8.678436279296875e-05, "kernel": 0.140625}, "out_proj": {"bias": 0.017822265625, "kernel": 0.28125}, "q_proj": {"bias": 0.009033203125, "kernel": 0.140625}, "v_proj": {"bias": 0.0286865234375, "kernel": 0.283203125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02587890625, "kernel": 0.474609375}, "output_dense": {"bias": 0.0174560546875, "kernel": 0.421875}}, "final_layer_norm": {"bias": 0.041748046875, "scale": 0.0380859375}, "layer_norm": {"bias": 0.052734375, "scale": 0.04052734375}}, "2": {"attention": {"k_proj": {"bias": 4.982948303222656e-05, "kernel": 0.07421875}, "out_proj": {"bias": 0.0177001953125, "kernel": 0.2060546875}, "q_proj": {"bias": 0.006195068359375, "kernel": 0.06982421875}, "v_proj": {"bias": 0.03173828125, "kernel": 0.181640625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0234375, "kernel": 0.390625}, "output_dense": {"bias": 0.01556396484375, "kernel": 0.275390625}}, "final_layer_norm": {"bias": 0.047119140625, "scale": 0.03173828125}, "layer_norm": {"bias": 0.0556640625, "scale": 0.07275390625}}, "20": {"attention": {"k_proj": {"bias": 2.110004425048828e-05, "kernel": 0.095703125}, "out_proj": {"bias": 0.0185546875, "kernel": 0.142578125}, "q_proj": {"bias": 0.005157470703125, "kernel": 0.0947265625}, "v_proj": {"bias": 0.0263671875, "kernel": 0.140625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0250244140625, "kernel": 0.4765625}, "output_dense": {"bias": 0.018310546875, "kernel": 0.390625}}, "final_layer_norm": {"bias": 0.041015625, "scale": 0.04443359375}, "layer_norm": {"bias": 0.046142578125, "scale": 0.038330078125}}, "21": {"attention": {"k_proj": {"bias": 4.00543212890625e-05, "kernel": 0.1259765625}, "out_proj": {"bias": 0.0189208984375, "kernel": 0.2216796875}, "q_proj": {"bias": 0.006927490234375, "kernel": 0.12890625}, "v_proj": {"bias": 0.02734375, "kernel": 0.203125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0267333984375, "kernel": 0.51953125}, "output_dense": {"bias": 0.0185546875, 
"kernel": 0.41796875}}, "final_layer_norm": {"bias": 0.04541015625, "scale": 0.04736328125}, "layer_norm": {"bias": 0.044189453125, "scale": 0.054443359375}}, "22": {"attention": {"k_proj": {"bias": 3.3855438232421875e-05, "kernel": 0.1181640625}, "out_proj": {"bias": 0.019775390625, "kernel": 0.240234375}, "q_proj": {"bias": 0.006011962890625, "kernel": 0.11279296875}, "v_proj": {"bias": 0.028076171875, "kernel": 0.21875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0269775390625, "kernel": 0.515625}, "output_dense": {"bias": 0.0194091796875, "kernel": 0.400390625}}, "final_layer_norm": {"bias": 0.046142578125, "scale": 0.047119140625}, "layer_norm": {"bias": 0.049560546875, "scale": 0.0458984375}}, "23": {"attention": {"k_proj": {"bias": 0.0001087188720703125, "kernel": 0.16015625}, "out_proj": {"bias": 0.0198974609375, "kernel": 0.443359375}, "q_proj": {"bias": 0.008544921875, "kernel": 0.1630859375}, "v_proj": {"bias": 0.03173828125, "kernel": 0.35546875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0263671875, "kernel": 0.53125}, "output_dense": {"bias": 0.01953125, "kernel": 0.400390625}}, "final_layer_norm": {"bias": 0.044677734375, "scale": 0.04638671875}, "layer_norm": {"bias": 0.05615234375, "scale": 0.056396484375}}, "24": {"attention": {"k_proj": {"bias": 6.246566772460938e-05, "kernel": 0.169921875}, "out_proj": {"bias": 0.0191650390625, "kernel": 0.36328125}, "q_proj": {"bias": 0.00933837890625, "kernel": 0.18359375}, "v_proj": {"bias": 0.03271484375, "kernel": 0.328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02685546875, "kernel": 0.5390625}, "output_dense": {"bias": 0.01904296875, "kernel": 0.37890625}}, "final_layer_norm": {"bias": 0.04736328125, "scale": 0.04345703125}, "layer_norm": {"bias": 0.0625, "scale": 0.041015625}}, "25": {"attention": {"k_proj": {"bias": 6.079673767089844e-05, "kernel": 0.169921875}, "out_proj": {"bias": 0.0196533203125, "kernel": 0.3125}, "q_proj": {"bias": 0.00860595703125, "kernel": 0.16015625}, "v_proj": {"bias": 0.03271484375, "kernel": 0.32421875}}, "feed_forward": {"intermediate_dense": {"bias": 0.02734375, "kernel": 0.55859375}, "output_dense": {"bias": 0.01953125, "kernel": 0.375}}, "final_layer_norm": {"bias": 0.050537109375, "scale": 0.0478515625}, "layer_norm": {"bias": 0.06005859375, "scale": 0.06298828125}}, "26": {"attention": {"k_proj": {"bias": 5.9604644775390625e-05, "kernel": 0.169921875}, "out_proj": {"bias": 0.01953125, "kernel": 0.29296875}, "q_proj": {"bias": 0.01025390625, "kernel": 0.177734375}, "v_proj": {"bias": 0.0341796875, "kernel": 0.296875}}, "feed_forward": {"intermediate_dense": {"bias": 0.026611328125, "kernel": 0.51171875}, "output_dense": {"bias": 0.01904296875, "kernel": 0.353515625}}, "final_layer_norm": {"bias": 0.0478515625, "scale": 0.04443359375}, "layer_norm": {"bias": 0.060791015625, "scale": 0.06396484375}}, "27": {"attention": {"k_proj": {"bias": 0.00011396408081054688, "kernel": 0.2021484375}, "out_proj": {"bias": 0.01806640625, "kernel": 0.44921875}, "q_proj": {"bias": 0.01068115234375, "kernel": 0.2138671875}, "v_proj": {"bias": 0.03466796875, "kernel": 0.435546875}}, "feed_forward": {"intermediate_dense": {"bias": 0.02734375, "kernel": 0.515625}, "output_dense": {"bias": 0.0181884765625, "kernel": 0.36328125}}, "final_layer_norm": {"bias": 0.05078125, "scale": 0.045654296875}, "layer_norm": {"bias": 0.06640625, "scale": 0.04931640625}}, "28": {"attention": {"k_proj": {"bias": 0.0001049041748046875, "kernel": 0.20703125}, "out_proj": {"bias": 0.0164794921875, "kernel": 
0.392578125}, "q_proj": {"bias": 0.01165771484375, "kernel": 0.208984375}, "v_proj": {"bias": 0.031494140625, "kernel": 0.404296875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0234375, "kernel": 0.45703125}, "output_dense": {"bias": 0.016357421875, "kernel": 0.326171875}}, "final_layer_norm": {"bias": 0.04248046875, "scale": 0.044921875}, "layer_norm": {"bias": 0.0673828125, "scale": 0.08447265625}}, "29": {"attention": {"k_proj": {"bias": 9.918212890625e-05, "kernel": 0.267578125}, "out_proj": {"bias": 0.0157470703125, "kernel": 0.28515625}, "q_proj": {"bias": 0.01495361328125, "kernel": 0.265625}, "v_proj": {"bias": 0.02978515625, "kernel": 0.31640625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0234375, "kernel": 0.494140625}, "output_dense": {"bias": 0.01531982421875, "kernel": 0.296875}}, "final_layer_norm": {"bias": 0.03955078125, "scale": 0.03515625}, "layer_norm": {"bias": 0.0654296875, "scale": 0.061279296875}}, "3": {"attention": {"k_proj": {"bias": 0.00012111663818359375, "kernel": 0.0986328125}, "out_proj": {"bias": 0.016845703125, "kernel": 0.314453125}, "q_proj": {"bias": 0.00726318359375, "kernel": 0.0888671875}, "v_proj": {"bias": 0.0283203125, "kernel": 0.2470703125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0242919921875, "kernel": 0.3828125}, "output_dense": {"bias": 0.0150146484375, "kernel": 0.271484375}}, "final_layer_norm": {"bias": 0.0458984375, "scale": 0.03125}, "layer_norm": {"bias": 0.0498046875, "scale": 0.0380859375}}, "30": {"attention": {"k_proj": {"bias": 0.0001220703125, "kernel": 0.13671875}, "out_proj": {"bias": 0.0152587890625, "kernel": 0.328125}, "q_proj": {"bias": 0.006378173828125, "kernel": 0.138671875}, "v_proj": {"bias": 0.029296875, "kernel": 0.3671875}}, "feed_forward": {"intermediate_dense": {"bias": 0.023681640625, "kernel": 0.51953125}, "output_dense": {"bias": 0.01446533203125, "kernel": 0.26171875}}, "final_layer_norm": {"bias": 0.04150390625, "scale": 0.03564453125}, "layer_norm": {"bias": 0.04931640625, "scale": 0.037109375}}, "31": {"attention": {"k_proj": {"bias": 0.00010347366333007812, "kernel": 0.14453125}, "out_proj": {"bias": 0.0140380859375, "kernel": 0.29296875}, "q_proj": {"bias": 0.006378173828125, "kernel": 0.134765625}, "v_proj": {"bias": 0.0238037109375, "kernel": 0.314453125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02392578125, "kernel": 0.51953125}, "output_dense": {"bias": 0.01385498046875, "kernel": 0.2578125}}, "final_layer_norm": {"bias": 0.0390625, "scale": 0.03662109375}, "layer_norm": {"bias": 0.039306640625, "scale": 0.0291748046875}}, "32": {"attention": {"k_proj": {"bias": 8.296966552734375e-05, "kernel": 0.15625}, "out_proj": {"bias": 0.01263427734375, "kernel": 0.28125}, "q_proj": {"bias": 0.0079345703125, "kernel": 0.1533203125}, "v_proj": {"bias": 0.0264892578125, "kernel": 0.4921875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0216064453125, "kernel": 0.431640625}, "output_dense": {"bias": 0.01129150390625, "kernel": 0.212890625}}, "final_layer_norm": {"bias": 0.04150390625, "scale": 0.03271484375}, "layer_norm": {"bias": 0.046630859375, "scale": 0.05419921875}}, "33": {"attention": {"k_proj": {"bias": 5.9604644775390625e-05, "kernel": 0.166015625}, "out_proj": {"bias": 0.01092529296875, "kernel": 0.2275390625}, "q_proj": {"bias": 0.008544921875, "kernel": 0.166015625}, "v_proj": {"bias": 0.023193359375, "kernel": 0.34765625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0196533203125, "kernel": 0.390625}, "output_dense": {"bias": 0.00897216796875, 
"kernel": 0.1875}}, "final_layer_norm": {"bias": 0.04345703125, "scale": 0.0361328125}, "layer_norm": {"bias": 0.039794921875, "scale": 0.0498046875}}, "34": {"attention": {"k_proj": {"bias": 0.0002346038818359375, "kernel": 0.158203125}, "out_proj": {"bias": 0.0081787109375, "kernel": 0.181640625}, "q_proj": {"bias": 0.006927490234375, "kernel": 0.14453125}, "v_proj": {"bias": 0.0177001953125, "kernel": 0.25390625}}, "feed_forward": {"intermediate_dense": {"bias": 0.01434326171875, "kernel": 0.291015625}, "output_dense": {"bias": 0.0072021484375, "kernel": 0.1748046875}}, "final_layer_norm": {"bias": 0.028076171875, "scale": 0.025146484375}, "layer_norm": {"bias": 0.03369140625, "scale": 0.026611328125}}, "35": {"attention": {"k_proj": {"bias": 0.0001506805419921875, "kernel": 0.10791015625}, "out_proj": {"bias": 0.00640869140625, "kernel": 0.2109375}, "q_proj": {"bias": 0.004852294921875, "kernel": 0.10791015625}, "v_proj": {"bias": 0.01177978515625, "kernel": 0.21484375}}, "feed_forward": {"intermediate_dense": {"bias": 0.010498046875, "kernel": 0.2119140625}, "output_dense": {"bias": 0.005889892578125, "kernel": 0.15234375}}, "final_layer_norm": {"bias": 0.0206298828125, "scale": 0.0220947265625}, "layer_norm": {"bias": 0.024169921875, "scale": 0.02880859375}}, "36": {"attention": {"k_proj": {"bias": 4.410743713378906e-05, "kernel": 0.1005859375}, "out_proj": {"bias": 0.005645751953125, "kernel": 0.1552734375}, "q_proj": {"bias": 0.00445556640625, "kernel": 0.095703125}, "v_proj": {"bias": 0.00946044921875, "kernel": 0.14453125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0089111328125, "kernel": 0.177734375}, "output_dense": {"bias": 0.0050048828125, "kernel": 0.111328125}}, "final_layer_norm": {"bias": 0.017578125, "scale": 0.01513671875}, "layer_norm": {"bias": 0.0191650390625, "scale": 0.01806640625}}, "37": {"attention": {"k_proj": {"bias": 9.441375732421875e-05, "kernel": 0.0849609375}, "out_proj": {"bias": 0.004913330078125, "kernel": 0.11474609375}, "q_proj": {"bias": 0.00390625, "kernel": 0.0830078125}, "v_proj": {"bias": 0.00897216796875, "kernel": 0.1318359375}}, "feed_forward": {"intermediate_dense": {"bias": 0.00823974609375, "kernel": 0.16796875}, "output_dense": {"bias": 0.004241943359375, "kernel": 0.09716796875}}, "final_layer_norm": {"bias": 0.015869140625, "scale": 0.01434326171875}, "layer_norm": {"bias": 0.019287109375, "scale": 0.015869140625}}, "38": {"attention": {"k_proj": {"bias": 5.650520324707031e-05, "kernel": 0.09130859375}, "out_proj": {"bias": 0.0040283203125, "kernel": 0.11865234375}, "q_proj": {"bias": 0.00396728515625, "kernel": 0.08642578125}, "v_proj": {"bias": 0.007354736328125, "kernel": 0.1279296875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0072021484375, "kernel": 0.150390625}, "output_dense": {"bias": 0.0034637451171875, "kernel": 0.09423828125}}, "final_layer_norm": {"bias": 0.0152587890625, "scale": 0.0146484375}, "layer_norm": {"bias": 0.0162353515625, "scale": 0.0135498046875}}, "39": {"attention": {"k_proj": {"bias": 5.316734313964844e-05, "kernel": 0.09619140625}, "out_proj": {"bias": 0.0030975341796875, "kernel": 0.09619140625}, "q_proj": {"bias": 0.00408935546875, "kernel": 0.0908203125}, "v_proj": {"bias": 0.006011962890625, "kernel": 0.10986328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.005401611328125, "kernel": 0.12109375}, "output_dense": {"bias": 0.0025634765625, "kernel": 0.08642578125}}, "final_layer_norm": {"bias": 0.01202392578125, "scale": 0.01226806640625}, "layer_norm": {"bias": 
0.0150146484375, "scale": 0.01556396484375}}, "4": {"attention": {"k_proj": {"bias": 0.000148773193359375, "kernel": 0.10498046875}, "out_proj": {"bias": 0.015869140625, "kernel": 0.361328125}, "q_proj": {"bias": 0.0072021484375, "kernel": 0.1005859375}, "v_proj": {"bias": 0.026123046875, "kernel": 0.3046875}}, "feed_forward": {"intermediate_dense": {"bias": 0.023193359375, "kernel": 0.36328125}, "output_dense": {"bias": 0.014404296875, "kernel": 0.29296875}}, "final_layer_norm": {"bias": 0.042724609375, "scale": 0.034423828125}, "layer_norm": {"bias": 0.0478515625, "scale": 0.060546875}}, "40": {"attention": {"k_proj": {"bias": 5.269050598144531e-05, "kernel": 0.046875}, "out_proj": {"bias": 0.0025787353515625, "kernel": 0.080078125}, "q_proj": {"bias": 0.0020294189453125, "kernel": 0.0458984375}, "v_proj": {"bias": 0.004150390625, "kernel": 0.07080078125}}, "feed_forward": {"intermediate_dense": {"bias": 0.004302978515625, "kernel": 0.09326171875}, "output_dense": {"bias": 0.0023040771484375, "kernel": 0.060791015625}}, "final_layer_norm": {"bias": 0.0087890625, "scale": 0.011474609375}, "layer_norm": {"bias": 0.00823974609375, "scale": 0.007781982421875}}, "41": {"attention": {"k_proj": {"bias": 4.0531158447265625e-05, "kernel": 0.0673828125}, "out_proj": {"bias": 0.002044677734375, "kernel": 0.087890625}, "q_proj": {"bias": 0.0025787353515625, "kernel": 0.0634765625}, "v_proj": {"bias": 0.00439453125, "kernel": 0.1044921875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0035858154296875, "kernel": 0.091796875}, "output_dense": {"bias": 0.00168609619140625, "kernel": 0.064453125}}, "final_layer_norm": {"bias": 0.00921630859375, "scale": 0.0106201171875}, "layer_norm": {"bias": 0.010986328125, "scale": 0.0128173828125}}, "42": {"attention": {"k_proj": {"bias": 1.1801719665527344e-05, "kernel": 0.02099609375}, "out_proj": {"bias": 0.001678466796875, "kernel": 0.048828125}, "q_proj": {"bias": 0.0009307861328125, "kernel": 0.02197265625}, "v_proj": {"bias": 0.002349853515625, "kernel": 0.0478515625}}, "feed_forward": {"intermediate_dense": {"bias": 0.002685546875, "kernel": 0.07421875}, "output_dense": {"bias": 0.0014801025390625, "kernel": 0.053955078125}}, "final_layer_norm": {"bias": 0.0059814453125, "scale": 0.00885009765625}, "layer_norm": {"bias": 0.0045166015625, "scale": 0.00439453125}}, "43": {"attention": {"k_proj": {"bias": 8.046627044677734e-06, "kernel": 0.01806640625}, "out_proj": {"bias": 0.0015106201171875, "kernel": 0.035888671875}, "q_proj": {"bias": 0.00095367431640625, "kernel": 0.0208740234375}, "v_proj": {"bias": 0.00171661376953125, "kernel": 0.031005859375}}, "feed_forward": {"intermediate_dense": {"bias": 0.002777099609375, "kernel": 0.083984375}, "output_dense": {"bias": 0.00125885009765625, "kernel": 0.052734375}}, "final_layer_norm": {"bias": 0.00689697265625, "scale": 0.007110595703125}, "layer_norm": {"bias": 0.0036468505859375, "scale": 0.005706787109375}}, "44": {"attention": {"k_proj": {"bias": 1.3113021850585938e-05, "kernel": 0.023681640625}, "out_proj": {"bias": 0.001312255859375, "kernel": 0.033447265625}, "q_proj": {"bias": 0.000946044921875, "kernel": 0.021484375}, "v_proj": {"bias": 0.0017547607421875, "kernel": 0.03515625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0023956298828125, "kernel": 0.0771484375}, "output_dense": {"bias": 0.001129150390625, "kernel": 0.052001953125}}, "final_layer_norm": {"bias": 0.005706787109375, "scale": 0.005859375}, "layer_norm": {"bias": 0.0042724609375, "scale": 0.004730224609375}}, "45": {"attention": 
{"k_proj": {"bias": 1.4424324035644531e-05, "kernel": 0.01239013671875}, "out_proj": {"bias": 0.00110626220703125, "kernel": 0.0267333984375}, "q_proj": {"bias": 0.00136566162109375, "kernel": 0.029541015625}, "v_proj": {"bias": 0.00142669677734375, "kernel": 0.027587890625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0019989013671875, "kernel": 0.06396484375}, "output_dense": {"bias": 0.000881195068359375, "kernel": 0.05615234375}}, "final_layer_norm": {"bias": 0.006072998046875, "scale": 0.006591796875}, "layer_norm": {"bias": 0.004791259765625, "scale": 0.00494384765625}}, "46": {"attention": {"k_proj": {"bias": 5.745887756347656e-05, "kernel": 0.006439208984375}, "out_proj": {"bias": 0.000888824462890625, "kernel": 0.0289306640625}, "q_proj": {"bias": 0.000591278076171875, "kernel": 0.011962890625}, "v_proj": {"bias": 0.0010986328125, "kernel": 0.0233154296875}}, "feed_forward": {"intermediate_dense": {"bias": 0.00141143798828125, "kernel": 0.03955078125}, "output_dense": {"bias": 0.000873565673828125, "kernel": 0.0478515625}}, "final_layer_norm": {"bias": 0.00433349609375, "scale": 0.00433349609375}, "layer_norm": {"bias": 0.0036468505859375, "scale": 0.003814697265625}}, "47": {"attention": {"k_proj": {"bias": 0.00011301040649414062, "kernel": 0.003997802734375}, "out_proj": {"bias": 0.000896453857421875, "kernel": 0.06640625}, "q_proj": {"bias": 0.00014591217041015625, "kernel": 0.00286865234375}, "v_proj": {"bias": 0.00118255615234375, "kernel": 0.0230712890625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0010528564453125, "kernel": 0.0252685546875}, "output_dense": {"bias": 0.000881195068359375, "kernel": 0.1787109375}}, "final_layer_norm": {"bias": 0.005950927734375, "scale": 0.00677490234375}, "layer_norm": {"bias": 0.005859375, "scale": 0.005767822265625}}, "5": {"attention": {"k_proj": {"bias": 6.4849853515625e-05, "kernel": 0.123046875}, "out_proj": {"bias": 0.0159912109375, "kernel": 0.203125}, "q_proj": {"bias": 0.007598876953125, "kernel": 0.12158203125}, "v_proj": {"bias": 0.02685546875, "kernel": 0.1953125}}, "feed_forward": {"intermediate_dense": {"bias": 0.023193359375, "kernel": 0.33984375}, "output_dense": {"bias": 0.0147705078125, "kernel": 0.275390625}}, "final_layer_norm": {"bias": 0.041015625, "scale": 0.037841796875}, "layer_norm": {"bias": 0.052734375, "scale": 0.04736328125}}, "6": {"attention": {"k_proj": {"bias": 7.152557373046875e-05, "kernel": 0.1318359375}, "out_proj": {"bias": 0.0152587890625, "kernel": 0.349609375}, "q_proj": {"bias": 0.00823974609375, "kernel": 0.119140625}, "v_proj": {"bias": 0.02685546875, "kernel": 0.31640625}}, "feed_forward": {"intermediate_dense": {"bias": 0.022216796875, "kernel": 0.341796875}, "output_dense": {"bias": 0.014404296875, "kernel": 0.271484375}}, "final_layer_norm": {"bias": 0.0380859375, "scale": 0.0322265625}, "layer_norm": {"bias": 0.049560546875, "scale": 0.058349609375}}, "7": {"attention": {"k_proj": {"bias": 7.62939453125e-05, "kernel": 0.1328125}, "out_proj": {"bias": 0.0150146484375, "kernel": 0.349609375}, "q_proj": {"bias": 0.00921630859375, "kernel": 0.126953125}, "v_proj": {"bias": 0.0255126953125, "kernel": 0.30859375}}, "feed_forward": {"intermediate_dense": {"bias": 0.021484375, "kernel": 0.33984375}, "output_dense": {"bias": 0.01409912109375, "kernel": 0.26171875}}, "final_layer_norm": {"bias": 0.038818359375, "scale": 0.033935546875}, "layer_norm": {"bias": 0.05126953125, "scale": 0.050048828125}}, "8": {"attention": {"k_proj": {"bias": 7.915496826171875e-05, "kernel": 
0.12060546875}, "out_proj": {"bias": 0.01507568359375, "kernel": 0.302734375}, "q_proj": {"bias": 0.007568359375, "kernel": 0.11474609375}, "v_proj": {"bias": 0.0262451171875, "kernel": 0.27734375}}, "feed_forward": {"intermediate_dense": {"bias": 0.023193359375, "kernel": 0.361328125}, "output_dense": {"bias": 0.01409912109375, "kernel": 0.275390625}}, "final_layer_norm": {"bias": 0.044677734375, "scale": 0.033447265625}, "layer_norm": {"bias": 0.0498046875, "scale": 0.06103515625}}, "9": {"attention": {"k_proj": {"bias": 0.00011777877807617188, "kernel": 0.1513671875}, "out_proj": {"bias": 0.0140380859375, "kernel": 0.416015625}, "q_proj": {"bias": 0.00830078125, "kernel": 0.1376953125}, "v_proj": {"bias": 0.0238037109375, "kernel": 0.40234375}}, "feed_forward": {"intermediate_dense": {"bias": 0.0213623046875, "kernel": 0.3515625}, "output_dense": {"bias": 0.0135498046875, "kernel": 0.279296875}}, "final_layer_norm": {"bias": 0.037109375, "scale": 0.0322265625}, "layer_norm": {"bias": 0.04443359375, "scale": 0.04833984375}}}, "pos_conv_embed": {"conv": {"bias": 0.034912109375, "weight_g": 0.044189453125, "weight_v": 0.287109375}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "1": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "2": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "3": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "4": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "5": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "6": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}}}, "feature_projection": {"layer_norm": {"bias": 0.1337890625, "scale": 0.1611328125}, "projection": {"bias": 0.0556640625, "kernel": 1.0546875}}, "masked_spec_embed": 0.0}}, "layer_param_norm/": {"lm_head": {"bias": 0.7824921607971191, "kernel": 55.72966766357422}, "wav2vec2": {"encoder": {"layer_norm": {"bias": 59.6768798828125, "scale": 74.17054748535156}, "layers": {"0": {"attention": {"k_proj": {"bias": 0.37033993005752563, "kernel": 27.536663055419922}, "out_proj": {"bias": 1.6469175815582275, "kernel": 26.147050857543945}, "q_proj": {"bias": 1.5330281257629395, "kernel": 27.813282012939453}, "v_proj": {"bias": 0.44783300161361694, "kernel": 26.55841064453125}}, "feed_forward": {"intermediate_dense": {"bias": 1.9835489988327026, "kernel": 100.66567993164062}, "output_dense": {"bias": 1.116748571395874, "kernel": 96.76679992675781}}, "final_layer_norm": {"bias": 1.335214376449585, "scale": 19.85782241821289}, "layer_norm": {"bias": 2.923041343688965, "scale": 15.398418426513672}}, "1": {"attention": {"k_proj": {"bias": 0.3767347037792206, "kernel": 41.013240814208984}, "out_proj": {"bias": 1.3653483390808105, "kernel": 43.371070861816406}, "q_proj": {"bias": 3.0925614833831787, "kernel": 41.05661392211914}, "v_proj": {"bias": 0.2924947738647461, "kernel": 41.61189270019531}}, "feed_forward": {"intermediate_dense": {"bias": 1.9882662296295166, "kernel": 98.7442626953125}, "output_dense": {"bias": 0.8527815341949463, "kernel": 87.83541870117188}}, "final_layer_norm": {"bias": 1.3605934381484985, "scale": 19.084806442260742}, "layer_norm": {"bias": 1.9318764209747314, "scale": 17.761367797851562}}, "10": {"attention": {"k_proj": {"bias": 0.4123449921607971, "kernel": 49.44670486450195}, "out_proj": {"bias": 
1.3130683898925781, "kernel": 52.27025604248047}, "q_proj": {"bias": 2.48445200920105, "kernel": 49.528873443603516}, "v_proj": {"bias": 0.3416975736618042, "kernel": 52.344085693359375}}, "feed_forward": {"intermediate_dense": {"bias": 1.974550485610962, "kernel": 102.70410919189453}, "output_dense": {"bias": 0.5955485105514526, "kernel": 95.81275939941406}}, "final_layer_norm": {"bias": 2.3762400150299072, "scale": 20.81279754638672}, "layer_norm": {"bias": 1.806241512298584, "scale": 21.429487228393555}}, "11": {"attention": {"k_proj": {"bias": 0.4460787773132324, "kernel": 49.37338638305664}, "out_proj": {"bias": 1.1512463092803955, "kernel": 51.949554443359375}, "q_proj": {"bias": 2.5446064472198486, "kernel": 49.20353698730469}, "v_proj": {"bias": 0.40995872020721436, "kernel": 52.2607536315918}}, "feed_forward": {"intermediate_dense": {"bias": 2.019528388977051, "kernel": 103.56025695800781}, "output_dense": {"bias": 0.5711302757263184, "kernel": 97.53776550292969}}, "final_layer_norm": {"bias": 2.3660812377929688, "scale": 20.919677734375}, "layer_norm": {"bias": 1.7802445888519287, "scale": 22.01519203186035}}, "12": {"attention": {"k_proj": {"bias": 0.4312320351600647, "kernel": 50.07032775878906}, "out_proj": {"bias": 1.126122236251831, "kernel": 51.988826751708984}, "q_proj": {"bias": 2.4090728759765625, "kernel": 49.92729949951172}, "v_proj": {"bias": 0.4024103879928589, "kernel": 52.296756744384766}}, "feed_forward": {"intermediate_dense": {"bias": 2.0548508167266846, "kernel": 104.54823303222656}, "output_dense": {"bias": 0.5548778772354126, "kernel": 99.2693099975586}}, "final_layer_norm": {"bias": 2.2933573722839355, "scale": 20.85626983642578}, "layer_norm": {"bias": 1.8587299585342407, "scale": 22.473487854003906}}, "13": {"attention": {"k_proj": {"bias": 0.4430793821811676, "kernel": 51.76431655883789}, "out_proj": {"bias": 1.1271920204162598, "kernel": 51.86852264404297}, "q_proj": {"bias": 2.359200954437256, "kernel": 51.75225830078125}, "v_proj": {"bias": 0.3906242251396179, "kernel": 51.92781066894531}}, "feed_forward": {"intermediate_dense": {"bias": 2.0941619873046875, "kernel": 105.30806732177734}, "output_dense": {"bias": 0.5719542503356934, "kernel": 99.8712158203125}}, "final_layer_norm": {"bias": 2.2314066886901855, "scale": 21.027400970458984}, "layer_norm": {"bias": 1.9997800588607788, "scale": 22.84510040283203}}, "14": {"attention": {"k_proj": {"bias": 0.43604815006256104, "kernel": 51.92181396484375}, "out_proj": {"bias": 1.2681217193603516, "kernel": 49.762760162353516}, "q_proj": {"bias": 2.4942922592163086, "kernel": 52.049808502197266}, "v_proj": {"bias": 0.36820662021636963, "kernel": 49.26283264160156}}, "feed_forward": {"intermediate_dense": {"bias": 2.1335830688476562, "kernel": 105.94457244873047}, "output_dense": {"bias": 0.6063626408576965, "kernel": 101.24859619140625}}, "final_layer_norm": {"bias": 2.2754664421081543, "scale": 21.145992279052734}, "layer_norm": {"bias": 2.1295526027679443, "scale": 22.584672927856445}}, "15": {"attention": {"k_proj": {"bias": 0.45931702852249146, "kernel": 51.93058776855469}, "out_proj": {"bias": 1.3753983974456787, "kernel": 50.94449234008789}, "q_proj": {"bias": 2.5871663093566895, "kernel": 52.099769592285156}, "v_proj": {"bias": 0.459650456905365, "kernel": 50.5831298828125}}, "feed_forward": {"intermediate_dense": {"bias": 2.132938861846924, "kernel": 105.57443237304688}, "output_dense": {"bias": 0.7701732516288757, "kernel": 101.94094848632812}}, "final_layer_norm": {"bias": 2.327320098876953, "scale": 
21.192947387695312}, "layer_norm": {"bias": 2.4148712158203125, "scale": 23.526634216308594}}, "16": {"attention": {"k_proj": {"bias": 0.4008745551109314, "kernel": 51.772621154785156}, "out_proj": {"bias": 1.27531099319458, "kernel": 50.134521484375}, "q_proj": {"bias": 2.667466163635254, "kernel": 51.75814437866211}, "v_proj": {"bias": 0.3768249750137329, "kernel": 49.78093719482422}}, "feed_forward": {"intermediate_dense": {"bias": 2.1094985008239746, "kernel": 106.10227966308594}, "output_dense": {"bias": 0.7860437631607056, "kernel": 102.66590881347656}}, "final_layer_norm": {"bias": 2.337951421737671, "scale": 21.583194732666016}, "layer_norm": {"bias": 2.283249855041504, "scale": 22.168060302734375}}, "17": {"attention": {"k_proj": {"bias": 0.3966267704963684, "kernel": 51.728885650634766}, "out_proj": {"bias": 1.2151354551315308, "kernel": 49.4556884765625}, "q_proj": {"bias": 2.714320182800293, "kernel": 51.81880187988281}, "v_proj": {"bias": 0.42661017179489136, "kernel": 49.11927032470703}}, "feed_forward": {"intermediate_dense": {"bias": 2.1101765632629395, "kernel": 107.14872741699219}, "output_dense": {"bias": 0.8218655586242676, "kernel": 103.06423950195312}}, "final_layer_norm": {"bias": 2.383938789367676, "scale": 22.070323944091797}, "layer_norm": {"bias": 2.222898483276367, "scale": 21.219982147216797}}, "18": {"attention": {"k_proj": {"bias": 0.4409676194190979, "kernel": 52.41611099243164}, "out_proj": {"bias": 1.3447906970977783, "kernel": 50.491905212402344}, "q_proj": {"bias": 2.614685535430908, "kernel": 52.796600341796875}, "v_proj": {"bias": 0.4518332779407501, "kernel": 50.001895904541016}}, "feed_forward": {"intermediate_dense": {"bias": 2.144195556640625, "kernel": 107.41338348388672}, "output_dense": {"bias": 0.9481453895568848, "kernel": 104.72514343261719}}, "final_layer_norm": {"bias": 2.5390124320983887, "scale": 22.15178680419922}, "layer_norm": {"bias": 2.424910068511963, "scale": 23.585906982421875}}, "19": {"attention": {"k_proj": {"bias": 0.38193291425704956, "kernel": 51.511146545410156}, "out_proj": {"bias": 1.3303101062774658, "kernel": 50.10035705566406}, "q_proj": {"bias": 2.930327892303467, "kernel": 51.865638732910156}, "v_proj": {"bias": 0.4086824655532837, "kernel": 49.38078308105469}}, "feed_forward": {"intermediate_dense": {"bias": 2.1912901401519775, "kernel": 107.95254516601562}, "output_dense": {"bias": 1.0248571634292603, "kernel": 105.65098571777344}}, "final_layer_norm": {"bias": 2.4923481941223145, "scale": 22.505674362182617}, "layer_norm": {"bias": 2.2888314723968506, "scale": 22.31826400756836}}, "2": {"attention": {"k_proj": {"bias": 0.454792320728302, "kernel": 47.77275085449219}, "out_proj": {"bias": 1.256988525390625, "kernel": 45.969764709472656}, "q_proj": {"bias": 3.2510807514190674, "kernel": 47.61664581298828}, "v_proj": {"bias": 0.339598685503006, "kernel": 45.72273254394531}}, "feed_forward": {"intermediate_dense": {"bias": 1.9737317562103271, "kernel": 103.32754516601562}, "output_dense": {"bias": 0.7398276329040527, "kernel": 91.11263275146484}}, "final_layer_norm": {"bias": 1.5421981811523438, "scale": 21.561111450195312}, "layer_norm": {"bias": 1.7081801891326904, "scale": 20.852447509765625}}, "20": {"attention": {"k_proj": {"bias": 0.4067543148994446, "kernel": 51.605438232421875}, "out_proj": {"bias": 1.359946370124817, "kernel": 49.45553207397461}, "q_proj": {"bias": 2.8498687744140625, "kernel": 52.224571228027344}, "v_proj": {"bias": 0.36227869987487793, "kernel": 48.43864822387695}}, "feed_forward": 
{"intermediate_dense": {"bias": 2.1725549697875977, "kernel": 109.17405700683594}, "output_dense": {"bias": 1.1388803720474243, "kernel": 106.40528106689453}}, "final_layer_norm": {"bias": 2.435314655303955, "scale": 23.4317626953125}, "layer_norm": {"bias": 2.231672525405884, "scale": 22.230525970458984}}, "21": {"attention": {"k_proj": {"bias": 0.4161534905433655, "kernel": 51.942527770996094}, "out_proj": {"bias": 1.403618335723877, "kernel": 49.51059341430664}, "q_proj": {"bias": 2.7690629959106445, "kernel": 52.67078399658203}, "v_proj": {"bias": 0.41060006618499756, "kernel": 48.64883041381836}}, "feed_forward": {"intermediate_dense": {"bias": 2.2174296379089355, "kernel": 109.5155029296875}, "output_dense": {"bias": 1.253208041191101, "kernel": 106.88243865966797}}, "final_layer_norm": {"bias": 2.4632763862609863, "scale": 23.175764083862305}, "layer_norm": {"bias": 2.2785892486572266, "scale": 22.234222412109375}}, "22": {"attention": {"k_proj": {"bias": 0.45357397198677063, "kernel": 52.54576110839844}, "out_proj": {"bias": 1.349219560623169, "kernel": 49.533172607421875}, "q_proj": {"bias": 2.8105549812316895, "kernel": 52.86981201171875}, "v_proj": {"bias": 0.3973655700683594, "kernel": 49.33363342285156}}, "feed_forward": {"intermediate_dense": {"bias": 2.1619315147399902, "kernel": 109.95498657226562}, "output_dense": {"bias": 1.3076066970825195, "kernel": 106.3852310180664}}, "final_layer_norm": {"bias": 2.3642821311950684, "scale": 22.684059143066406}, "layer_norm": {"bias": 2.3316237926483154, "scale": 21.545879364013672}}, "23": {"attention": {"k_proj": {"bias": 0.4928613007068634, "kernel": 53.47669219970703}, "out_proj": {"bias": 1.564335823059082, "kernel": 50.98707580566406}, "q_proj": {"bias": 2.7065773010253906, "kernel": 53.582611083984375}, "v_proj": {"bias": 0.5810648202896118, "kernel": 51.54853820800781}}, "feed_forward": {"intermediate_dense": {"bias": 2.131969690322876, "kernel": 109.86410522460938}, "output_dense": {"bias": 1.2769315242767334, "kernel": 107.37890625}}, "final_layer_norm": {"bias": 2.767916679382324, "scale": 22.887813568115234}, "layer_norm": {"bias": 2.824352264404297, "scale": 23.373172760009766}}, "24": {"attention": {"k_proj": {"bias": 0.46056002378463745, "kernel": 52.424072265625}, "out_proj": {"bias": 1.6070430278778076, "kernel": 52.50334167480469}, "q_proj": {"bias": 2.828113079071045, "kernel": 52.40515899658203}, "v_proj": {"bias": 0.5424190163612366, "kernel": 52.51116180419922}}, "feed_forward": {"intermediate_dense": {"bias": 2.2367913722991943, "kernel": 109.35035705566406}, "output_dense": {"bias": 1.3016372919082642, "kernel": 110.30095672607422}}, "final_layer_norm": {"bias": 2.83841872215271, "scale": 22.964658737182617}, "layer_norm": {"bias": 2.56215763092041, "scale": 22.983924865722656}}, "25": {"attention": {"k_proj": {"bias": 0.42509031295776367, "kernel": 52.730464935302734}, "out_proj": {"bias": 1.363797664642334, "kernel": 50.5806884765625}, "q_proj": {"bias": 2.9342763423919678, "kernel": 52.548744201660156}, "v_proj": {"bias": 0.6404213309288025, "kernel": 51.0885009765625}}, "feed_forward": {"intermediate_dense": {"bias": 2.1367578506469727, "kernel": 109.70021057128906}, "output_dense": {"bias": 1.1017413139343262, "kernel": 110.27072143554688}}, "final_layer_norm": {"bias": 2.5763301849365234, "scale": 23.494670867919922}, "layer_norm": {"bias": 2.683134078979492, "scale": 21.88357925415039}}, "26": {"attention": {"k_proj": {"bias": 0.4836847186088562, "kernel": 53.01764678955078}, "out_proj": {"bias": 
1.2433912754058838, "kernel": 51.37077331542969}, "q_proj": {"bias": 2.943906784057617, "kernel": 52.80891036987305}, "v_proj": {"bias": 0.5064959526062012, "kernel": 52.004638671875}}, "feed_forward": {"intermediate_dense": {"bias": 2.2763516902923584, "kernel": 109.44652557373047}, "output_dense": {"bias": 1.0912110805511475, "kernel": 107.40899658203125}}, "final_layer_norm": {"bias": 2.1937994956970215, "scale": 22.433353424072266}, "layer_norm": {"bias": 2.497119903564453, "scale": 22.19057273864746}}, "27": {"attention": {"k_proj": {"bias": 0.5808594226837158, "kernel": 53.76898956298828}, "out_proj": {"bias": 1.5447406768798828, "kernel": 52.95805358886719}, "q_proj": {"bias": 2.703345775604248, "kernel": 53.69578552246094}, "v_proj": {"bias": 0.6748642325401306, "kernel": 53.388118743896484}}, "feed_forward": {"intermediate_dense": {"bias": 2.404933452606201, "kernel": 107.8713150024414}, "output_dense": {"bias": 0.9485896825790405, "kernel": 107.17198181152344}}, "final_layer_norm": {"bias": 2.5252954959869385, "scale": 21.88959503173828}, "layer_norm": {"bias": 2.6147172451019287, "scale": 23.32440948486328}}, "28": {"attention": {"k_proj": {"bias": 0.5901432037353516, "kernel": 54.482521057128906}, "out_proj": {"bias": 1.5367379188537598, "kernel": 53.31493377685547}, "q_proj": {"bias": 2.9472482204437256, "kernel": 54.1741943359375}, "v_proj": {"bias": 0.5131911039352417, "kernel": 53.759761810302734}}, "feed_forward": {"intermediate_dense": {"bias": 2.3475265502929688, "kernel": 107.87416076660156}, "output_dense": {"bias": 0.8224154710769653, "kernel": 109.1680908203125}}, "final_layer_norm": {"bias": 2.425306797027588, "scale": 22.337677001953125}, "layer_norm": {"bias": 2.0914058685302734, "scale": 23.993711471557617}}, "29": {"attention": {"k_proj": {"bias": 0.46781182289123535, "kernel": 51.12034606933594}, "out_proj": {"bias": 1.5021522045135498, "kernel": 55.685630798339844}, "q_proj": {"bias": 2.809702157974243, "kernel": 51.00274658203125}, "v_proj": {"bias": 0.4760415554046631, "kernel": 55.703304290771484}}, "feed_forward": {"intermediate_dense": {"bias": 2.297222137451172, "kernel": 108.01033020019531}, "output_dense": {"bias": 0.9597339630126953, "kernel": 113.12825012207031}}, "final_layer_norm": {"bias": 2.5980498790740967, "scale": 23.459980010986328}, "layer_norm": {"bias": 2.245180130004883, "scale": 25.39927864074707}}, "3": {"attention": {"k_proj": {"bias": 0.45006245374679565, "kernel": 52.03215789794922}, "out_proj": {"bias": 1.4254932403564453, "kernel": 48.60858917236328}, "q_proj": {"bias": 2.8560738563537598, "kernel": 52.312644958496094}, "v_proj": {"bias": 0.3246268630027771, "kernel": 48.768699645996094}}, "feed_forward": {"intermediate_dense": {"bias": 1.9663825035095215, "kernel": 104.83622741699219}, "output_dense": {"bias": 0.6984099745750427, "kernel": 94.07957458496094}}, "final_layer_norm": {"bias": 1.8095453977584839, "scale": 21.664737701416016}, "layer_norm": {"bias": 1.9017157554626465, "scale": 22.739452362060547}}, "30": {"attention": {"k_proj": {"bias": 0.5024805665016174, "kernel": 52.825706481933594}, "out_proj": {"bias": 1.3023658990859985, "kernel": 52.053871154785156}, "q_proj": {"bias": 2.907101631164551, "kernel": 52.91836166381836}, "v_proj": {"bias": 0.49308842420578003, "kernel": 52.49382019042969}}, "feed_forward": {"intermediate_dense": {"bias": 2.2399911880493164, "kernel": 108.17861938476562}, "output_dense": {"bias": 0.9140658378601074, "kernel": 112.09104919433594}}, "final_layer_norm": {"bias": 2.4926414489746094, 
"scale": 24.492368698120117}, "layer_norm": {"bias": 2.316732168197632, "scale": 24.931156158447266}}, "31": {"attention": {"k_proj": {"bias": 0.5412741899490356, "kernel": 51.240806579589844}, "out_proj": {"bias": 1.2333163022994995, "kernel": 52.19988250732422}, "q_proj": {"bias": 2.6581294536590576, "kernel": 51.346168518066406}, "v_proj": {"bias": 0.5469827651977539, "kernel": 52.432586669921875}}, "feed_forward": {"intermediate_dense": {"bias": 2.3097352981567383, "kernel": 106.72758483886719}, "output_dense": {"bias": 1.0891624689102173, "kernel": 109.24717712402344}}, "final_layer_norm": {"bias": 2.2962756156921387, "scale": 24.31252670288086}, "layer_norm": {"bias": 2.3430848121643066, "scale": 24.590187072753906}}, "32": {"attention": {"k_proj": {"bias": 0.4704548716545105, "kernel": 50.39933776855469}, "out_proj": {"bias": 1.2453913688659668, "kernel": 51.57465362548828}, "q_proj": {"bias": 2.8450098037719727, "kernel": 50.34847640991211}, "v_proj": {"bias": 0.419519305229187, "kernel": 51.972320556640625}}, "feed_forward": {"intermediate_dense": {"bias": 2.2569355964660645, "kernel": 105.33018493652344}, "output_dense": {"bias": 1.146787166595459, "kernel": 108.35574340820312}}, "final_layer_norm": {"bias": 2.31538724899292, "scale": 24.518985748291016}, "layer_norm": {"bias": 2.417579174041748, "scale": 24.991830825805664}}, "33": {"attention": {"k_proj": {"bias": 0.48390907049179077, "kernel": 50.28266143798828}, "out_proj": {"bias": 1.280959129333496, "kernel": 51.30557632446289}, "q_proj": {"bias": 2.998173713684082, "kernel": 50.253868103027344}, "v_proj": {"bias": 0.4416005611419678, "kernel": 51.71035385131836}}, "feed_forward": {"intermediate_dense": {"bias": 2.279946804046631, "kernel": 103.67684173583984}, "output_dense": {"bias": 1.1754591464996338, "kernel": 106.81501007080078}}, "final_layer_norm": {"bias": 2.257889747619629, "scale": 24.207740783691406}, "layer_norm": {"bias": 2.5865607261657715, "scale": 25.067882537841797}}, "34": {"attention": {"k_proj": {"bias": 0.45390456914901733, "kernel": 49.26523208618164}, "out_proj": {"bias": 1.5265402793884277, "kernel": 52.46200942993164}, "q_proj": {"bias": 2.913527011871338, "kernel": 49.268951416015625}, "v_proj": {"bias": 0.4023621678352356, "kernel": 52.534793853759766}}, "feed_forward": {"intermediate_dense": {"bias": 2.3745017051696777, "kernel": 102.21138000488281}, "output_dense": {"bias": 1.1227428913116455, "kernel": 105.72161102294922}}, "final_layer_norm": {"bias": 2.20273494720459, "scale": 23.640857696533203}, "layer_norm": {"bias": 2.615731716156006, "scale": 25.498104095458984}}, "35": {"attention": {"k_proj": {"bias": 0.5336894989013672, "kernel": 51.04902648925781}, "out_proj": {"bias": 1.4900906085968018, "kernel": 51.15006637573242}, "q_proj": {"bias": 2.573650360107422, "kernel": 51.323951721191406}, "v_proj": {"bias": 0.4889468550682068, "kernel": 51.176795959472656}}, "feed_forward": {"intermediate_dense": {"bias": 2.502547264099121, "kernel": 100.7535400390625}, "output_dense": {"bias": 1.0254027843475342, "kernel": 104.25007629394531}}, "final_layer_norm": {"bias": 2.295243263244629, "scale": 23.61981964111328}, "layer_norm": {"bias": 2.499863862991333, "scale": 26.093618392944336}}, "36": {"attention": {"k_proj": {"bias": 0.4473738670349121, "kernel": 48.319740295410156}, "out_proj": {"bias": 1.5177178382873535, "kernel": 52.26478958129883}, "q_proj": {"bias": 2.6124308109283447, "kernel": 48.236427307128906}, "v_proj": {"bias": 0.39316776394844055, "kernel": 52.66499328613281}}, 
"feed_forward": {"intermediate_dense": {"bias": 2.3669564723968506, "kernel": 99.60517120361328}, "output_dense": {"bias": 1.0260361433029175, "kernel": 103.68067932128906}}, "final_layer_norm": {"bias": 2.044203519821167, "scale": 24.13540267944336}, "layer_norm": {"bias": 2.2894434928894043, "scale": 25.63806915283203}}, "37": {"attention": {"k_proj": {"bias": 0.6240901350975037, "kernel": 47.300960540771484}, "out_proj": {"bias": 1.7604304552078247, "kernel": 52.189971923828125}, "q_proj": {"bias": 2.3819239139556885, "kernel": 47.316253662109375}, "v_proj": {"bias": 0.38518577814102173, "kernel": 52.32656478881836}}, "feed_forward": {"intermediate_dense": {"bias": 2.2694783210754395, "kernel": 98.56558227539062}, "output_dense": {"bias": 1.0087021589279175, "kernel": 103.13001251220703}}, "final_layer_norm": {"bias": 1.7903382778167725, "scale": 24.500085830688477}, "layer_norm": {"bias": 2.238807439804077, "scale": 25.563133239746094}}, "38": {"attention": {"k_proj": {"bias": 0.722891092300415, "kernel": 45.45281219482422}, "out_proj": {"bias": 1.4463956356048584, "kernel": 51.495086669921875}, "q_proj": {"bias": 2.2622861862182617, "kernel": 45.454437255859375}, "v_proj": {"bias": 0.42936205863952637, "kernel": 51.565425872802734}}, "feed_forward": {"intermediate_dense": {"bias": 2.2012171745300293, "kernel": 96.4342041015625}, "output_dense": {"bias": 0.9817801713943481, "kernel": 101.30387115478516}}, "final_layer_norm": {"bias": 1.7825044393539429, "scale": 25.214679718017578}, "layer_norm": {"bias": 2.4084482192993164, "scale": 26.425636291503906}}, "39": {"attention": {"k_proj": {"bias": 0.7209377884864807, "kernel": 45.24742889404297}, "out_proj": {"bias": 1.7097458839416504, "kernel": 51.329002380371094}, "q_proj": {"bias": 2.1152358055114746, "kernel": 45.53700637817383}, "v_proj": {"bias": 0.4246286153793335, "kernel": 51.27728271484375}}, "feed_forward": {"intermediate_dense": {"bias": 2.177624225616455, "kernel": 94.23902893066406}, "output_dense": {"bias": 1.0458879470825195, "kernel": 101.17718505859375}}, "final_layer_norm": {"bias": 1.753927230834961, "scale": 25.785865783691406}, "layer_norm": {"bias": 2.33198881149292, "scale": 26.942312240600586}}, "4": {"attention": {"k_proj": {"bias": 0.44548165798187256, "kernel": 54.6234130859375}, "out_proj": {"bias": 1.652343988418579, "kernel": 50.16497039794922}, "q_proj": {"bias": 2.615248918533325, "kernel": 54.93098068237305}, "v_proj": {"bias": 0.34892427921295166, "kernel": 50.320098876953125}}, "feed_forward": {"intermediate_dense": {"bias": 1.9531786441802979, "kernel": 104.50968933105469}, "output_dense": {"bias": 0.8575068712234497, "kernel": 95.54541015625}}, "final_layer_norm": {"bias": 1.9920002222061157, "scale": 21.200180053710938}, "layer_norm": {"bias": 2.054612159729004, "scale": 23.620338439941406}}, "40": {"attention": {"k_proj": {"bias": 0.6590453386306763, "kernel": 44.198089599609375}, "out_proj": {"bias": 1.6252505779266357, "kernel": 49.55699920654297}, "q_proj": {"bias": 1.9674756526947021, "kernel": 44.89208221435547}, "v_proj": {"bias": 0.4587768614292145, "kernel": 49.23614501953125}}, "feed_forward": {"intermediate_dense": {"bias": 2.0333969593048096, "kernel": 92.16896057128906}, "output_dense": {"bias": 1.087776780128479, "kernel": 98.40738677978516}}, "final_layer_norm": {"bias": 1.7852704524993896, "scale": 25.04292106628418}, "layer_norm": {"bias": 2.2756104469299316, "scale": 26.40799903869629}}, "41": {"attention": {"k_proj": {"bias": 1.7133712768554688, "kernel": 41.96858596801758}, 
"out_proj": {"bias": 1.3790823221206665, "kernel": 51.25593566894531}, "q_proj": {"bias": 1.71382737159729, "kernel": 42.56317901611328}, "v_proj": {"bias": 0.4695759415626526, "kernel": 50.369529724121094}}, "feed_forward": {"intermediate_dense": {"bias": 2.110393524169922, "kernel": 88.92567443847656}, "output_dense": {"bias": 1.1446669101715088, "kernel": 97.37409973144531}}, "final_layer_norm": {"bias": 2.23917293548584, "scale": 28.507984161376953}, "layer_norm": {"bias": 2.22525691986084, "scale": 28.246891021728516}}, "42": {"attention": {"k_proj": {"bias": 0.8601109981536865, "kernel": 38.31235885620117}, "out_proj": {"bias": 1.4427157640457153, "kernel": 45.07648849487305}, "q_proj": {"bias": 1.549715280532837, "kernel": 39.524009704589844}, "v_proj": {"bias": 0.6933339834213257, "kernel": 43.49076461791992}}, "feed_forward": {"intermediate_dense": {"bias": 1.9107489585876465, "kernel": 88.009765625}, "output_dense": {"bias": 1.1978566646575928, "kernel": 95.7593994140625}}, "final_layer_norm": {"bias": 1.9227323532104492, "scale": 29.817535400390625}, "layer_norm": {"bias": 1.6761282682418823, "scale": 26.810440063476562}}, "43": {"attention": {"k_proj": {"bias": 1.247081995010376, "kernel": 34.694725036621094}, "out_proj": {"bias": 1.4174811840057373, "kernel": 41.36320495605469}, "q_proj": {"bias": 1.3773530721664429, "kernel": 35.38981628417969}, "v_proj": {"bias": 0.5787136554718018, "kernel": 39.29212951660156}}, "feed_forward": {"intermediate_dense": {"bias": 1.8936835527420044, "kernel": 87.073974609375}, "output_dense": {"bias": 0.9419379234313965, "kernel": 93.74283599853516}}, "final_layer_norm": {"bias": 1.99924898147583, "scale": 32.0491943359375}, "layer_norm": {"bias": 1.7940990924835205, "scale": 25.131242752075195}}, "44": {"attention": {"k_proj": {"bias": 2.5188145637512207, "kernel": 35.16314697265625}, "out_proj": {"bias": 1.1667875051498413, "kernel": 45.019126892089844}, "q_proj": {"bias": 1.317202091217041, "kernel": 35.58300018310547}, "v_proj": {"bias": 0.38874924182891846, "kernel": 44.14718246459961}}, "feed_forward": {"intermediate_dense": {"bias": 1.9462898969650269, "kernel": 86.07953643798828}, "output_dense": {"bias": 0.859266996383667, "kernel": 91.58765411376953}}, "final_layer_norm": {"bias": 2.0454273223876953, "scale": 34.2881965637207}, "layer_norm": {"bias": 1.6815991401672363, "scale": 25.14142608642578}}, "45": {"attention": {"k_proj": {"bias": 2.081407308578491, "kernel": 34.86139678955078}, "out_proj": {"bias": 1.0356104373931885, "kernel": 48.59937286376953}, "q_proj": {"bias": 1.402512788772583, "kernel": 35.03264617919922}, "v_proj": {"bias": 0.4231463074684143, "kernel": 48.76853942871094}}, "feed_forward": {"intermediate_dense": {"bias": 2.016927719116211, "kernel": 82.93773651123047}, "output_dense": {"bias": 0.9764893054962158, "kernel": 87.24796295166016}}, "final_layer_norm": {"bias": 1.9180456399917603, "scale": 33.143672943115234}, "layer_norm": {"bias": 1.5726068019866943, "scale": 23.782546997070312}}, "46": {"attention": {"k_proj": {"bias": 1.5659263134002686, "kernel": 35.878021240234375}, "out_proj": {"bias": 0.8182340264320374, "kernel": 51.16078186035156}, "q_proj": {"bias": 1.5642974376678467, "kernel": 36.18907165527344}, "v_proj": {"bias": 0.4092414081096649, "kernel": 51.89159393310547}}, "feed_forward": {"intermediate_dense": {"bias": 2.0093321800231934, "kernel": 77.47581481933594}, "output_dense": {"bias": 1.1406863927841187, "kernel": 77.73695373535156}}, "final_layer_norm": {"bias": 1.8108854293823242, "scale": 
28.70657730102539}, "layer_norm": {"bias": 1.3991491794586182, "scale": 22.808137893676758}}, "47": {"attention": {"k_proj": {"bias": 0.6173280477523804, "kernel": 38.678985595703125}, "out_proj": {"bias": 0.6758822202682495, "kernel": 46.45281219482422}, "q_proj": {"bias": 1.7084776163101196, "kernel": 39.426841735839844}, "v_proj": {"bias": 0.4932914674282074, "kernel": 47.617279052734375}}, "feed_forward": {"intermediate_dense": {"bias": 1.986911654472351, "kernel": 75.47482299804688}, "output_dense": {"bias": 0.6346586346626282, "kernel": 72.82707214355469}}, "final_layer_norm": {"bias": 1.1888140439987183, "scale": 23.650447845458984}, "layer_norm": {"bias": 1.2521969079971313, "scale": 20.66573715209961}}, "5": {"attention": {"k_proj": {"bias": 0.42588678002357483, "kernel": 50.1945686340332}, "out_proj": {"bias": 1.6038882732391357, "kernel": 51.2144889831543}, "q_proj": {"bias": 2.7522244453430176, "kernel": 50.37500762939453}, "v_proj": {"bias": 0.3343381881713867, "kernel": 51.71652603149414}}, "feed_forward": {"intermediate_dense": {"bias": 1.8887722492218018, "kernel": 104.60663604736328}, "output_dense": {"bias": 0.8976269960403442, "kernel": 94.77360534667969}}, "final_layer_norm": {"bias": 2.1965675354003906, "scale": 21.37998390197754}, "layer_norm": {"bias": 2.0435237884521484, "scale": 22.437192916870117}}, "6": {"attention": {"k_proj": {"bias": 0.4843112528324127, "kernel": 51.87700653076172}, "out_proj": {"bias": 1.5925445556640625, "kernel": 50.83113479614258}, "q_proj": {"bias": 2.7889723777770996, "kernel": 52.3514404296875}, "v_proj": {"bias": 0.3247200846672058, "kernel": 51.107398986816406}}, "feed_forward": {"intermediate_dense": {"bias": 1.8638136386871338, "kernel": 103.7142333984375}, "output_dense": {"bias": 0.752193808555603, "kernel": 94.57742309570312}}, "final_layer_norm": {"bias": 2.5145251750946045, "scale": 20.836563110351562}, "layer_norm": {"bias": 2.0285890102386475, "scale": 23.156789779663086}}, "7": {"attention": {"k_proj": {"bias": 0.5048109889030457, "kernel": 51.46453094482422}, "out_proj": {"bias": 1.4398455619812012, "kernel": 51.139068603515625}, "q_proj": {"bias": 2.550907611846924, "kernel": 51.92047119140625}, "v_proj": {"bias": 0.42719271779060364, "kernel": 50.953285217285156}}, "feed_forward": {"intermediate_dense": {"bias": 1.8739449977874756, "kernel": 103.49991607666016}, "output_dense": {"bias": 0.5876641273498535, "kernel": 94.39387512207031}}, "final_layer_norm": {"bias": 2.416801929473877, "scale": 21.010677337646484}, "layer_norm": {"bias": 1.9788501262664795, "scale": 22.19708824157715}}, "8": {"attention": {"k_proj": {"bias": 0.49711495637893677, "kernel": 51.12122344970703}, "out_proj": {"bias": 1.2548246383666992, "kernel": 51.65118408203125}, "q_proj": {"bias": 2.541980504989624, "kernel": 51.03407287597656}, "v_proj": {"bias": 0.35420340299606323, "kernel": 51.662872314453125}}, "feed_forward": {"intermediate_dense": {"bias": 1.9298467636108398, "kernel": 103.21916198730469}, "output_dense": {"bias": 0.5482766628265381, "kernel": 93.97574615478516}}, "final_layer_norm": {"bias": 2.3526053428649902, "scale": 20.7393856048584}, "layer_norm": {"bias": 1.9221248626708984, "scale": 22.40435028076172}}, "9": {"attention": {"k_proj": {"bias": 0.5231171250343323, "kernel": 52.01068878173828}, "out_proj": {"bias": 1.4968843460083008, "kernel": 52.671897888183594}, "q_proj": {"bias": 2.4629459381103516, "kernel": 52.26807403564453}, "v_proj": {"bias": 0.38445231318473816, "kernel": 52.86597442626953}}, "feed_forward": 
{"intermediate_dense": {"bias": 2.026733875274658, "kernel": 101.98421478271484}, "output_dense": {"bias": 0.6828575134277344, "kernel": 94.36962890625}}, "final_layer_norm": {"bias": 2.325080156326294, "scale": 20.160720825195312}, "layer_norm": {"bias": 2.0236480236053467, "scale": 24.083864212036133}}}, "pos_conv_embed": {"conv": {"bias": 5.847014427185059, "weight_g": 9.12463665008545, "weight_v": 93.52015686035156}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 2.0290679931640625, "kernel": 20.55536460876465}, "layer_norm": {"bias": 4.550922393798828, "scale": 16.167570114135742}}, "1": {"conv": {"bias": 1.7790228128433228, "kernel": 51.24136734008789}, "layer_norm": {"bias": 5.962646961212158, "scale": 23.268157958984375}}, "2": {"conv": {"bias": 1.140576720237732, "kernel": 46.50312042236328}, "layer_norm": {"bias": 4.176670551300049, "scale": 20.370853424072266}}, "3": {"conv": {"bias": 0.6725863218307495, "kernel": 44.397525787353516}, "layer_norm": {"bias": 3.888174533843994, "scale": 17.53795051574707}}, "4": {"conv": {"bias": 0.6373162269592285, "kernel": 41.314056396484375}, "layer_norm": {"bias": 2.385471820831299, "scale": 16.34571647644043}}, "5": {"conv": {"bias": 0.5147221684455872, "kernel": 37.479759216308594}, "layer_norm": {"bias": 2.020900011062622, "scale": 17.064470291137695}}, "6": {"conv": {"bias": 0.4947893023490906, "kernel": 40.64780044555664}, "layer_norm": {"bias": 0.5876954793930054, "scale": 19.058603286743164}}}}, "feature_projection": {"layer_norm": {"bias": 6.376383304595947, "scale": 16.443069458007812}, "projection": {"bias": 1.8670344352722168, "kernel": 37.218414306640625}}, "masked_spec_embed": 11.914372444152832}}, "train/learning_rate": 1.9151924789184704e-05, "train/loss": 0.204779714345932, "train/param_norm": 1241.662353515625, "_runtime": 5032, "_timestamp": 1660135191, "_step": 275600, "_wandb": {"runtime": 5033}} \ No newline at end of file diff --git a/wandb/run-20220810_111559-290849gb/logs/debug-internal.log b/wandb/run-20220810_111559-290849gb/logs/debug-internal.log new file mode 100644 index 0000000000000000000000000000000000000000..5797985128573c991b2a307adb276c43f1b5f324 --- /dev/null +++ b/wandb/run-20220810_111559-290849gb/logs/debug-internal.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a6262ea728ba649f48f692f8eb4a7d194e2bcdfe33cede302167d1a4d75ecc09 +size 160449 diff --git a/wandb/run-20220810_111559-290849gb/logs/debug.log b/wandb/run-20220810_111559-290849gb/logs/debug.log new file mode 100644 index 0000000000000000000000000000000000000000..7e0b127024c5ee54238271ce778e50a41a2fee6b --- /dev/null +++ b/wandb/run-20220810_111559-290849gb/logs/debug.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1e44b22c4702845ce31095607b3ab0c73ebedf10e949748067ba3e70053ba0e +size 6378 diff --git a/wandb/run-20220810_111559-290849gb/run-290849gb.wandb b/wandb/run-20220810_111559-290849gb/run-290849gb.wandb new file mode 100644 index 0000000000000000000000000000000000000000..a79456cd800546354c0c2a4c530624a2e4f1caa1 --- /dev/null +++ b/wandb/run-20220810_111559-290849gb/run-290849gb.wandb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f22acfd4deb69c3ccc7826aa9a440ce8f7296ac3a483b3ed16e5eeab29a79ac +size 757033 diff --git a/wandb/run-20220810_145446-1k92sv35/files/code/run_flax_speech_recognition_ctc.py b/wandb/run-20220810_145446-1k92sv35/files/code/run_flax_speech_recognition_ctc.py new file mode 100644 index 
0000000000000000000000000000000000000000..4d0f5fc813b76949369e0fe19e8facee530f724a --- /dev/null +++ b/wandb/run-20220810_145446-1k92sv35/files/code/run_flax_speech_recognition_ctc.py @@ -0,0 +1,1632 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2022 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition. +""" +# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. + +import logging +import math +import os +import re +import sys +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + +import datasets +import numpy as np +from datasets import DatasetDict, load_dataset, load_metric +from tqdm import tqdm + +import flax +import jax +import jax.numpy as jnp +import optax +import transformers +import wandb as wandb +from flax import core, jax_utils, struct, traverse_util +from flax.jax_utils import unreplicate, pad_shard_unpad +from flax.training.common_utils import get_metrics, shard, shard_prng_key +from huggingface_hub import Repository +from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC +from optax._src import linear_algebra +from transformers import ( + AutoFeatureExtractor, + AutoProcessor, + AutoTokenizer, + HfArgumentParser, + TrainingArguments, + is_tensorboard_available, + set_seed, +) +from transformers.file_utils import get_full_repo_name +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.17.0.dev0") + +require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") + +logger = logging.getLogger(__name__) + + +@flax.struct.dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
+ """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + feature_extractor_name: Optional[str] = field( + default=None, metadata={"help": "feature extractor name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + freeze_feature_encoder: bool = field( + default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."} + ) + attention_dropout: float = field( + default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."} + ) + activation_dropout: float = field( + default=0.1, + metadata={ + "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler." + }, + ) + hidden_dropout: float = field( + default=0.1, + metadata={ + "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler." + }, + ) + feat_proj_dropout: float = field( + default=0.0, + metadata={ + "help": "The feat proj dropout probability for feature encoder representations." + }, + ) + final_dropout: float = field( + default=0.0, + metadata={"help": "The dropout probability for the final projection layer."}, + ) + mask_time_prob: float = field( + default=0.1, + metadata={ + "help": "The spec aug dropout probability for feature encoder representations." + }, + ) + mask_time_length: int = field( + default=10, + metadata={"help": "Length of vector span to mask along the time axis."}, + ) + mask_feature_prob: float = field( + default=0.0, + metadata={ + "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector" + "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis." + }, + ) + mask_feature_length: int = field( + default=10, + metadata={"help": "Length of vector span to mask along the feature axis."}, + ) + layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."}) + ctc_loss_reduction: Optional[str] = field( + default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."} + ) + ctc_zero_infinity: Optional[bool] = field( + default=False, metadata={"help": "If True, will try yo aboud the CTC loss goinf to infinity."} + ) + + +@flax.struct.dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. 
+ """ + + dataset_name: str = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + text_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, + ) + dataset_cache_dir: Optional[str] = field( + default=None, metadata={"help": "Path to cache directory for saving and loading datasets"} + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_test_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of test examples to this " + "value if set." + }, + ) + audio_column_name: str = field( + default="audio", + metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"}, + ) + text_column_name: str = field( + default="text", + metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"}, + ) + max_duration_in_seconds: float = field( + default=20.0, + metadata={ + "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`" + }, + ) + min_duration_in_seconds: float = field( + default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"} + ) + max_label_length: Optional[int] = field( + default=512, + metadata={ + "help": "The minimum total sequence length for target text after tokenization. Sequences shorter " + "than this will be filtered." + }, + ) + min_label_length: Optional[int] = field( + default=2, + metadata={ + "help": "The minimum total sequence length for target text after tokenization. Sequences shorter " + "than this will be filtered." + }, + ) + pad_input_to_multiple_of: Optional[int] = field( + default=32000, + metadata={ + "help": "If set will pad the input sequence to a multiple of the provided value. " + "This is important to avoid triggering recompilations on TPU." + }, + ) + pad_target_to_multiple_of: Optional[int] = field( + default=None, + metadata={ + "help": "If set will pad the target sequence to a multiple of the provided value. " + "This is important to avoid triggering recompilations on TPU." + }, + ) + preprocessing_only: bool = field( + default=False, + metadata={ + "help": "Whether to only do data preprocessing and skip training. " + "This is especially useful when data preprocessing errors out in distributed training due to timeout. 
" + "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` " + "so that the cached datasets can consequently be loaded in distributed training" + }, + ) + train_split_name: str = field( + default="train", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + eval_split_name: str = field( + default="validation", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + do_lower_case: bool = field( + default=True, + metadata={"help": "Whether the target text should be lower cased."}, + ) + wandb_project: str = field( + default="flax-speech-recognition-ctc", + metadata={"help": "The name of the wandb project."}, + ) + wandb_name: str = field( + default=None, + metadata={"help": "The name of the wandb run."}, + ) + wandb_job_type: str = field( + default="CTC", + metadata={"help": "The name of the wandb job type."}, + ) + test_split_name: str = field( + default="test", + metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"}, + ) + remove_punctuation: bool = field( + default=False, metadata={"help": "Whether or not to remove punctuation during training."} + ) + skip_steps: Optional[int] = field( + default=0, + metadata={ + "help": "Skip this number of steps. Useful to continue training" + }, + ) + + +# @flax.struct.dataclass +@dataclass +class FlaxTrainingArguments(TrainingArguments): + precision: str = field( + default="full", + metadata={ + "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision" + "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**" + }, + ) + matmul_precision: str = field( + default="default", + metadata={ + "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. " + "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). " + "This configuration option does not change the behaviours of such calls with explicit precision arguments; " + "it only changes the behaviors of calls with no such argument provided. " + "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`." + }, + ) + multisteps: bool = field( + default=False, + metadata={ + "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, " + "a custom gradient accumulation implementation will be employed." + }, + ) + + +def to_fp32(t): + return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t) + + +def to_bf16(t): + return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t) + + +class MixedPrecisionTrainState(struct.PyTreeNode): + """Train state for use with a single Optax optimizer. 
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py + + Synopsis:: + + state = TrainState.create( + apply_fn=model.apply, + params=variables['params'], + tx=tx) + grad_fn = jax.grad(make_loss_fn(state.apply_fn)) + for batch in data: + grads = grad_fn(state.params, batch) + state = state.apply_gradients(grads=grads) + + Args: + step: Counter starts at 0 and is incremented by every call to + `.apply_gradients()`. + apply_fn: Usually set to `model.apply()`. Kept in this dataclass for + convenience to have a shorter params list for the `train_step()` function + in your training loop. + params: The parameters to be updated by `tx` and used by `apply_fn`. + tx: An Optax gradient transformation. + opt_state: The state for `tx`. + dropout_rng: PRNG key for stochastic operations. + bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. + """ + + step: int + apply_fn: Callable = struct.field(pytree_node=False) + get_attention_mask_fn: Callable = struct.field(pytree_node=False) + params: core.FrozenDict[str, Any] + tx: optax.GradientTransformation = struct.field(pytree_node=False) + opt_state: optax.OptState + dropout_rng: jnp.ndarray + max_grad_norm: Optional[float] = 1.0 + + def apply_gradients(self, *, grads, to_dtype, **kwargs): + """Updates `step`, `params`, `opt_state` and `**kwargs` in return value. + + Note that internally this function calls `.tx.update()` followed by a call + to `optax.apply_updates()` to update `params` and `opt_state`. + + Args: + grads: Gradients that have the same pytree structure as `.params`. + **kwargs: Additional dataclass attributes that should be `.replace()`-ed. + + Returns: + An updated instance of `self` with `step` incremented by one, `params` + and `opt_state` updated by applying `grads`, and additional attributes + replaced as specified by `kwargs`. + """ + + # clip gradients by global l2 norm + casted_max_grad_norm = to_dtype(self.max_grad_norm) + g_norm = linear_algebra.global_norm(grads) + g_norm = jnp.maximum(casted_max_grad_norm, g_norm) + grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads) + + # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training + # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is) + updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params) + + new_params = optax.apply_updates(self.params, updates) + return self.replace( + step=self.step + 1, + params=new_params, + opt_state=to_dtype(new_opt_state), + **kwargs, + ) + + @classmethod + def create(cls, *, apply_fn, params, tx, to_dtype, step=0, **kwargs): + """Creates a new instance with `step=0` and initialized `opt_state`.""" + # downcast optimizer state to bf16 if mixed-precision training + opt_state = tx.init(to_dtype(params)) if tx is not None else None + return cls( + step=step, + apply_fn=apply_fn, + params=params, + tx=tx, + opt_state=opt_state, + **kwargs, + ) + + def replicate(self): + return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) + + +@flax.struct.dataclass +class FlaxDataCollatorSpeechSeq2SeqWithPadding: + """ + Data collator that will dynamically pad the inputs received. + Args: + processor ([`Wav2Vec2Processor`]) + The processor used for proccessing the data. + decoder_start_token_id (:obj: `int`) + The begin-of-sentence of the decoder. 
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned input sequences (according to the model's padding side and padding index) + among: + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned target sequences (according to the model's padding side and padding index). + See above for details. + max_input_length (:obj:`float`, `optional`): + Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). + pad_input_to_multiple_of (:obj:`int`, `optional`): + If set will pad the input sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + pad_target_to_multiple_of (:obj:`int`, `optional`): + If set will pad the target sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + """ + + processor: Any + input_padding: Union[bool, str] = "longest" + label_padding: Union[bool, str] = "max_length" + pad_input_to_multiple_of: Optional[int] = None + pad_to_multiple_of_label: Optional[int] = None + max_input_length: Optional[float] = None + max_label_length: Optional[float] = None + + def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]: + # split inputs and labels since they have to be of different lengths and need + # different padding methods + input_features = [{"input_values": feature["input_values"]} for feature in features] + label_features = [{"input_ids": feature["labels"]} for feature in features] + + # reformat list to dict and set to pytorch format + batch = self.processor.feature_extractor.pad( + input_features, + max_length=self.max_input_length, + padding=self.input_padding, + pad_to_multiple_of=self.pad_input_to_multiple_of, + return_tensors="np", + ) + + labels_batch = self.processor.tokenizer.pad( + label_features, + max_length=self.max_label_length, + padding=self.label_padding, + pad_to_multiple_of=self.pad_to_multiple_of_label, + return_tensors="np", + ) + + labels = labels_batch["input_ids"] + labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1)) + labels = labels.filled(fill_value=-100) + + batch["labels"] = labels + + return batch + + +def get_grouped_indices( + dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None +) -> np.array: + """ + Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486) + Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar + lengths. 
To do this, the indices are: + + - randomly permuted (if a JAX rng is specified) + - grouped in mega-batches of size `mega_batch_mult * batch_size` + - sorted by length in each mega-batch + + The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of + maximum length placed first, so that an OOM happens sooner rather than later. + """ + lengths = dataset["input_length"] + + # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller. + if mega_batch_mult is None: + mega_batch_mult = min(len(lengths) // (batch_size * 4), 50) + # Just in case, for tiny datasets + if mega_batch_mult == 0: + mega_batch_mult = 1 + + # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler. + num_samples = len(lengths) + indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples) + + megabatch_size = mega_batch_mult * batch_size + megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] + megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches] + + # The rest is to get the biggest batch first. + # Since each megabatch is sorted by descending length, the longest element is the first + megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches] + max_idx = np.argmax(megabatch_maximums).item() + # Switch to put the longest batch in first position + # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch) + megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0] + + megabatches = np.array([i for megabatch in megabatches for i in megabatch]) + + return megabatches + + +def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray: + """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by + the batch size and `drop_last` is `True`, the last incomplete batch is dropped. 
Else, it is returned.""" + num_samples = len(samples_idx) + if drop_last: + samples_to_remove = num_samples % batch_size + if samples_to_remove != 0: + samples_idx = samples_idx[:-samples_to_remove] + sections_split = num_samples // batch_size + samples_idx = samples_idx.reshape((sections_split, batch_size)) + else: + sections_split = math.ceil(num_samples / batch_size) + samples_idx = np.array_split(samples_idx, sections_split) + return samples_idx + + +def write_train_metric(summary_writer, train_metrics, train_time, step): + summary_writer.scalar("train_time", train_time, step) + + train_metrics = get_metrics(train_metrics) + for key, vals in train_metrics.items(): + tag = f"train_{key}" + for i, val in enumerate(vals): + summary_writer.scalar(tag, val, step - len(vals) + i + 1) + + +def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None): + for metric_name, value in eval_metrics.items(): + summary_writer.scalar(f"eval_{metric_name}", value, step) + + if pred_str is not None: + # write output actual predictions for debugging + summary_writer.text("eval_predictions", "\n".join(pred_str), step) + + +def write_wandb_log(metrics, step, prefix=None): + if jax.process_index() == 0: + log_metrics = {} + for k, v in metrics.items(): + if "layer" in k: + log_metrics[f"{k}/"] = v + elif prefix is not None: + log_metrics[f"{prefix}/{k}"] = v + else: + log_metrics[k] = v + wandb.log(log_metrics, step) + + +def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"): + if jax.process_index() == 0: + # convert str data to a wandb compatible format + str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))] + # we'll log the first 50 predictions for each epoch + wandb.log( + { + f"{prefix}/step_{int(step / 1000)}k": wandb.Table( + columns=["label_str", "pred_str"], data=str_data[:num_log] + ) + }, + step, + ) + + +def create_learning_rate_fn( + num_train_steps: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.array]: + """Returns a linear warmup, linear_decay learning rate function.""" + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +def ctc_loss( + logits, + logits_attention_mask, + labels, + blank_id, + loss_reduction="mean", + output_emission_dict=False, + log_epsilon=-100000.0, +): + """Computes CTC loss. + This function performs forward computation over an FSA with `N * 2` states + where `N` is the max number of labels. The states are split into two groups: + Phi states and emission states. a phi-state accepts repetition of + phi (blank)-symbols and transits to emission state when the correct label is + observed. An emission state accepts repetition of the label and transits to + the next phi states at any time (so called epsilon-transition). + Below, `B` denotes the batch size, `T` denotes the time steps in `logits`, + and `N` denotes the time steps in `labels`. + Args: + logits: (B, T, K)-array containing log-probabilities of each class. + logitpaddings: (B, T)-array. Padding indicators for `logits`. + labels: (B, N)-array containing reference integer labels. + labelpaddings: (B, N)-array. Padding indicators for `labels`. Currently, + `labels` must be right-padded, i.e. 
each row of `labelpaddings` must be + repetition of zeroes, followed by repetition of ones. + blank_id: Id for blank token. + loss_reduction: one of "mean", "sum", "default" + - "none": no reduction is applied. + - "mean": output loss will be divided by target lengths and then the + mean over the batch is taken. + - "sum": output loss are summed over batch + output_emission_dict: whether to output additional information about the emission probs + Returns: + A pair of `(per_seq_loss, aux)`. + per_seq_loss: + (B,)-array containing loss values for each sequence in the batch. + aux: Dictionary containing interim variables used for computing losses. + aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each + phi-state corresponding to the n-th label. + aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each + emission-state corresponding to the n-th label. + aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol + corresponding to each time frame. + aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label + corresponding to each time frame. + """ + # label paddings are indicated by -100 + labelpaddings = labels < 0 + # logit paddings are the inverse of attention_mask + logitpaddings = ~logits_attention_mask + + # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py + batchsize, unused_maxinputlen, num_classes = logits.shape + batchsize_, maxlabellen = labels.shape + + logprobs = jax.nn.log_softmax(logits) + labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32) + + # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1]. + repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32) + repeat = jnp.pad(repeat, ((0, 0), (0, 1))) + + logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1] + logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] + + one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K] + logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot) + logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] + + logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N] + logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0) + logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N] + + def loop_body(prev, x): + prev_phi, prev_emit = prev + # emit-to-phi epsilon transition, except if the next label is repetition + prev_phi_orig = prev_phi + prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat)) + + logprob_emit, logprob_phi, pad = x + + # phi-to-emit transition + next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit) + # self-loop transition + next_phi = prev_phi + logprob_phi + # emit-to-phi blank transition only when the next label is repetition + next_phi = next_phi.at[:, 1:].set( + jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)) + ) + + pad = pad.reshape((batchsize, 1)) + next_emit = pad * prev_emit + (1.0 - pad) * next_emit + next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi + + return (next_phi, next_emit), (next_phi, next_emit) + + xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0))) + _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs) + + # last row needs to be updated with the last epsilon transition + logalpha_phi_last = logalpha_phi[-1].at[:, 
1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1])) + logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last) + + # extract per_seq_loss + one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1] + per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot) + + if loss_reduction == "mean": + target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1) + loss = (per_seq_loss / target_lengths).mean() + elif loss_reduction == "sum": + loss = per_seq_loss.sum() + else: + loss = per_seq_loss + + if not output_emission_dict: + return loss + + return loss, { + "logalpha_phi": logalpha_phi, + "logalpha_emit": logalpha_emit, + "logprobs_phi": logprobs_phi, + "logprobs_emit": logprobs_emit, + } + + +def make_dataset(data_args, seed=42): + # Pre-processing dataset + import re + + def map_nst(entry): + text = entry["text"].lower() + text = text.replace("(...vær stille under dette opptaket...)", "") + text = re.sub('[áàâ]', 'a', text) + text = re.sub('[ä]', 'æ', text) + text = re.sub('[éèëê]', 'e', text) + text = re.sub('[íìïî]', 'i', text) + text = re.sub('[óòöô]', 'o', text) + text = re.sub('[ö]', 'ø', text) + text = re.sub('[ç]', 'c', text) + text = re.sub('[úùüû]', 'u', text) + # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text) + text = re.sub('\s+', ' ', text) + return {"text": text} + + def filter_nst(entry): + if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)): + return False # Too short + if re.match(entry["type"], "pIW|CA"): + return False # Spelling out words + return True + + def filter_npsc(entry): + # False if there are digits in the text + if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)): + return False # Too short + if re.search("\d", entry["text"]): + return False + return True + + def map_npsc(entry): + batch = {"text": entry["text"].lower()} + batch["text"] = re.sub('[áàâ]', 'a', batch["text"]) + batch["text"] = re.sub('[ä]', 'æ', batch["text"]) + batch["text"] = re.sub('[éèëê]', 'e', batch["text"]) + batch["text"] = re.sub('[íìïî]', 'i', batch["text"]) + batch["text"] = re.sub('[óòöô]', 'o', batch["text"]) + batch["text"] = re.sub('[ö]', 'ø', batch["text"]) + batch["text"] = re.sub('[ç]', 'c', batch["text"]) + batch["text"] = re.sub('[úùüû]', 'u', batch["text"]) + batch["text"] = re.sub('\s', ' ', batch["text"]) + batch["text"] = re.sub('', 'eee', batch["text"]) + batch["text"] = re.sub('', 'qqq', batch["text"]) + batch["text"] = re.sub('', 'mmm', batch["text"]) + batch["text"] = re.sub('', 'xxx', batch["text"]) + # batch["text"] = re.sub('', '?', batch["text"]) + if "<" in batch["text"]: + raise ValueError(batch["text"]) + return batch + + nst = datasets.load_dataset("NbAiLab/NST", "no-close") + npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3") + # TODO NST_hesitate + + split = len(npsc[data_args.train_split_name]) / (len(npsc[data_args.train_split_name]) + len(npsc[data_args.eval_split_name])) # Use same train/val ratio as NPSC + nst_train = nst[data_args.train_split_name].train_test_split(train_size=split, seed=seed) + nst[data_args.train_split_name] = nst_train["train"] + nst[data_args.eval_split_name] = nst_train["test"] + + nst = nst.filter(filter_nst).map( + map_nst, + num_proc=data_args.preprocessing_num_workers, + desc="filtering NST", + ).shuffle(seed=seed) + npsc = npsc.filter(filter_npsc).map( + map_npsc, + num_proc=data_args.preprocessing_num_workers, + desc="filtering NPSC", 
+ ).shuffle(seed=seed) + + npsc_base = npsc.remove_columns([col for col in npsc[data_args.train_split_name].column_names if col not in ["text", "audio"]]) + nst_base = nst.remove_columns([col for col in nst[data_args.train_split_name].column_names if col not in ["text", "audio"]]) + + combined = {} + for split in data_args.train_split_name, data_args.eval_split_name, data_args.test_split_name: + probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples + probs = (probs / probs.sum()).tolist() + comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed) + combined[split] = comb + + return datasets.DatasetDict(**combined) + +def main(): + # 1. Parse input arguments + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # 2. Setup logging + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + # Set the verbosity to info of the Transformers logger. + # We only want one process per machine to log things on the screen. + logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) + if jax.process_index() == 0: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # Set up wandb run + if jax.process_index() == 0: + wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type) + + logger.info("Training/evaluation parameters %s", training_args) + + # Set the default TPU matmul precision and display the number of devices + jax.config.update("jax_default_matmul_precision", training_args.matmul_precision) + logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}") + + # 4. 
Load dataset + + set_seed(training_args.seed) + raw_datasets = make_dataset(data_args, seed=training_args.seed) + + # raw_datasets = DatasetDict() + + # if training_args.do_train: + # raw_datasets[data_args.train_split_name] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=data_args.train_split_name, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + # if training_args.do_eval: + # raw_datasets[data_args.eval_split_name] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=data_args.eval_split_name, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + # if training_args.do_predict: + # test_split = data_args.test_split_name.split("+") + # for split in test_split: + # raw_datasets[split] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=split, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + if not training_args.do_train and not training_args.do_eval and not training_args.do_predict: + raise ValueError( + "Cannot not train, not do evaluation and not do prediction. At least one of " + "training, evaluation or prediction has to be done." + ) + + # if not training, there is no need to run multiple epochs + if not training_args.do_train: + training_args.num_train_epochs = 1 + + if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--audio_column_name` to the correct audio column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--text_column_name` to the correct text column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + # 5. 
Load pretrained model, tokenizer, and feature extractor + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + config = Wav2Vec2Config.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + # update config according to training args, model args, and tokenizer attributes + config.update( + { + "feat_proj_dropout": model_args.feat_proj_dropout, + "attention_dropout": model_args.attention_dropout, + "hidden_dropout": model_args.hidden_dropout, + "final_dropout": model_args.final_dropout, + "mask_time_prob": model_args.mask_time_prob, + "mask_time_length": model_args.mask_time_length, + "mask_feature_prob": model_args.mask_feature_prob, + "mask_feature_length": model_args.mask_feature_length, + "gradient_checkpointing": training_args.gradient_checkpointing, + "layerdrop": model_args.layerdrop, + "ctc_loss_reduction": model_args.ctc_loss_reduction, + "ctc_zero_infinity": model_args.ctc_zero_infinity, + "pad_token_id": tokenizer.pad_token_id, + "vocab_size": tokenizer.vocab_size, # len(tokenizer), + "activation_dropout": model_args.activation_dropout, + } + ) + + if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr": + raise ValueError( + "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to " + "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus," + "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is mostly likely " + "built on an lowercased corpus. In this case, set `tokenizer.do_lower_case` to ``False`." + ) + + if training_args.precision == "full_mixed": + dtype = jnp.bfloat16 + training_args.mixed_precision = True + elif training_args.precision == "half_mixed": + dtype = jnp.bfloat16 + training_args.mixed_precision = False + else: + dtype = jnp.float32 + training_args.mixed_precision = False + + try: + model = FlaxWav2Vec2ForCTC.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + except: + model = FlaxWav2Vec2ForCTC.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + from_pt=True, + ) + + # 6. Resample speech dataset ALWAYS + raw_datasets = raw_datasets.cast_column( + data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) + ) + + # 7. Preprocessing the datasets. 
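# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original script in this diff): how the
# duration bounds configured by the data arguments translate into the
# sample-count thresholds that the preprocessing below filters on. The 16 kHz
# rate and the 0.5 s / 30.0 s bounds mirror the CLI flags recorded in the wandb
# run metadata further down in this diff; here they are hard-coded assumptions,
# not values read from the parsed arguments.
import numpy as np

sampling_rate = 16_000                                    # assumed feature_extractor.sampling_rate
min_duration_in_seconds, max_duration_in_seconds = 0.5, 30.0
min_input_length = int(min_duration_in_seconds * sampling_rate)   # 8_000 samples
max_input_length = int(max_duration_in_seconds * sampling_rate)   # 480_000 samples

def is_audio_in_length_range(length: int) -> bool:
    # same predicate the script applies to the "input_length" column after `prepare_dataset`
    return min_input_length < length < max_input_length

durations = np.array([0.3, 2.5, 29.9, 31.0])              # toy clip lengths in seconds
lengths = (durations * sampling_rate).astype(int)          # clip lengths in samples
print([is_audio_in_length_range(int(n)) for n in lengths])  # [False, True, True, False]
# ----------------------------------------------------------------------------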
+    # We need to read the audio files as arrays and tokenize the targets.
+    max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
+    min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
+    max_target_length = data_args.max_label_length
+    min_target_length = data_args.min_label_length
+    pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
+    audio_column_name = data_args.audio_column_name
+    num_workers = data_args.preprocessing_num_workers
+    text_column_name = data_args.text_column_name
+    model_input_name = feature_extractor.model_input_names[0]
+    do_lower_case = data_args.do_lower_case
+    dataset_name = data_args.dataset_name
+    chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ")
+    chars_to_ignore_regex = f'[{"".join(chars_to_ignore)}]'
+    # gigaspeech_punctuation = {" ": ",", " ": ".", " ": "?", " ": "!"}
+    # gigaspeech_disfluencies = ["", ""]
+    # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "", "", "", "[laughter-",
+    # "[vocalized-noise]", "_1"]
+    # swb_punctuations = ["{", "}", "[", "]-", "]"]
+    # earnings_disfluencies = ["", "", "", "inaudible", "", ""]
+    ignore_segments = ["ignore_time_segment_in_scoring", "", "", "[noise]", "[laughter]", "[silence]",
+                       "[vocalized-noise]", "", "", "", "", "", "", ""]
+
+    if training_args.do_train and data_args.max_train_samples is not None:
+        raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].select(range(data_args.max_train_samples))
+
+    if training_args.do_eval and data_args.max_eval_samples is not None:
+        raw_datasets[data_args.eval_split_name] = raw_datasets[data_args.eval_split_name].select(range(data_args.max_eval_samples))
+
+    if training_args.do_predict and data_args.max_test_samples is not None:
+        # cap the test split with the test-sample limit (not the eval-sample limit)
+        raw_datasets[data_args.test_split_name] = raw_datasets[data_args.test_split_name].select(range(data_args.max_test_samples))
+
+    if training_args.do_train and data_args.remove_punctuation:
+
+        def remove_punctuation(batch):
+            batch[text_column_name] = (
+                re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "")
+            )
+            # return the updated example so that `datasets.Dataset.map` keeps the change
+            return batch
+
+        raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].map(
+            remove_punctuation,
+            num_proc=data_args.preprocessing_num_workers,
+            desc="removing punctuation from train split",
+        )
+
+    # filter data where the targets are ignored in scoring
+    def is_target_labels(input_str):
+        return input_str.lower() not in ignore_segments
+
+    raw_datasets = raw_datasets.filter(
+        is_target_labels,
+        num_proc=num_workers,
+        input_columns=[text_column_name],
+        desc="filtering data where the targets are ignored in scoring",
+    )
+
+    def prepare_dataset(batch):
+        # process audio
+        try:
+            sample = batch[audio_column_name]
+        except ValueError:
+            sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate}
+        inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
+        # process audio length
+        batch[model_input_name] = inputs.input_values[0]
+        batch["input_length"] = len(batch["input_values"])
+
+        # process targets
+        input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
+
+        # if dataset_name == "google/xtreme_s":
+        # # Finally, we tokenize the processed text
+        # batch["labels"] = tokenizer(input_str).input_ids
+        # batch["labels_length"] = len(batch["labels"])
+        # return batch
+
+        # # Common Voice 9
+        # if input_str.startswith('"') and input_str.endswith('"'):
+        # # we can remove 
trailing quotation marks as they do not affect the transcription + # input_str = input_str[1:-1] + # # normalize quotation marks + # input_str = re.sub(r'["“”]', '"', input_str) + # # normalize apostrophes + # input_str = re.sub(r"[’']", "'", input_str) + # # normalize hyphens + # input_str = re.sub(r"[—–]", "-", input_str) + # # replace double quotation marks with single + # input_str = input_str.replace('""', '"') + # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str): + # # for CV9, we'll normalize the text to always finish with punctuation + # if input_str[-1] not in [".", "?", "!"]: + # input_str = input_str + "." + + # # TEDLIUM-3 + # # delete the token from the text and replace spaced apostrophes with un-spaced + # input_str = input_str.replace("", "").replace(" '", "'") + + # # GigaSpeech + # for disfluency in gigaspeech_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # convert spelled out punctuation to symbolic form + # for punctuation, replacement in gigaspeech_punctuation.items(): + # input_str = input_str.replace(punctuation, replacement) + # if dataset_name == "speechcolab/gigaspeech" and len(input_str): + # # for GS, we'll normalize the text to always finish with punctuation + # if input_str[-1] not in [".", "?", "!"]: + # input_str = input_str + "." + + # # SWB + # for disfluency in swb_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # remove parenthesised text (test data only) + # input_str = re.sub("[\(].*?[\)]", "", input_str) + # for punctuation in swb_punctuations: + # input_str = input_str.replace(punctuation, "") + # # replace anomalous words with their correct transcriptions + # split_str = input_str.split("/") + # if len(split_str) > 1: + # input_str = " ".join( + # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]]) + + # # Earnings 22 + # for disfluency in earnings_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # replace mal-formatted ellipsis + # input_str = input_str.replace("…", ".") + + # JIWER compliance + # remove multiple spaces + input_str = re.sub(r"\s\s+", " ", input_str) + # strip trailing spaces + input_str = input_str.strip() + + # Finally, we tokenize the processed text + batch["labels"] = tokenizer(input_str).input_ids + batch["labels_length"] = len(batch["labels"]) + return batch + + vectorized_datasets = raw_datasets.map( + prepare_dataset, + remove_columns=next(iter(raw_datasets.values())).column_names, + num_proc=num_workers, + desc="preprocess dataset", + ) + + # filter data with inputs shorter than min_input_length or longer than max_input_length + def is_audio_in_length_range(length): + return length > min_input_length and length < max_input_length + + vectorized_datasets = vectorized_datasets.filter( + is_audio_in_length_range, + num_proc=num_workers, + input_columns=["input_length"], + ) + + # filter data with targets shorter than min_target_length or longer than max_target_length + def is_labels_in_length_range(length): + return length > min_target_length # and length < max_target_length + + vectorized_datasets = vectorized_datasets.filter( + is_labels_in_length_range, + num_proc=num_workers, + input_columns=["labels_length"], + ) + + # for large datasets it is advised to run the preprocessing on a + # single machine first with `args.preprocessing_only` since there will mostly likely + # be a timeout when running the script in distributed mode. 
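# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original script in this diff): a toy
# in-memory dataset showing how the two `filter` calls above prune examples by
# "input_length" and "labels_length". The thresholds are hypothetical stand-ins
# for the values the script computes from the data arguments.
from datasets import Dataset

toy = Dataset.from_dict(
    {"input_length": [4_800, 160_000, 900_000], "labels_length": [1, 25, 40]}
)
min_input_length, max_input_length = 8_000, 480_000   # assumed sample bounds
min_target_length = 2                                  # assumed label bound

toy = toy.filter(
    lambda length: min_input_length < length < max_input_length,
    input_columns=["input_length"],
)
toy = toy.filter(
    lambda length: length > min_target_length,
    input_columns=["labels_length"],
)
print(len(toy))  # 1 -> only the 160_000-sample / 25-label example survives
# ----------------------------------------------------------------------------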
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the + # cached dataset + if data_args.preprocessing_only: + cache = {k: v.cache_files for k, v in vectorized_datasets.items()} + logger.info(f"Data preprocessing finished. Files cached at {cache}.") + return + + # 8. Load Metrics + wer_metric = load_metric("wer") + cer_metric = load_metric("cer") + + def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]): + padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids)) + + pred_str = tokenizer.batch_decode(pred_ids) + # we do not want to group tokens when computing the metrics + label_str = tokenizer.batch_decode(padded_ids, group_tokens=False) + + wer = wer_metric.compute(predictions=pred_str, references=label_str) + cer = cer_metric.compute(predictions=pred_str, references=label_str) + + return {"wer": wer, "cer": cer}, pred_str, label_str + + # 9. save feature extractor, tokenizer and config + feature_extractor.save_pretrained(training_args.output_dir) + tokenizer.save_pretrained(training_args.output_dir) + config.save_pretrained(training_args.output_dir) + + processor = AutoProcessor.from_pretrained(training_args.output_dir) + + data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding( + processor=processor, + input_padding="longest", + pad_input_to_multiple_of=pad_input_to_multiple_of, + max_label_length=data_args.max_label_length, + ) + + # Enable tensorboard only on the master node + has_tensorboard = is_tensorboard_available() + if has_tensorboard and jax.process_index() == 0: + try: + from flax.metrics.tensorboard import SummaryWriter + + summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) + except ImportError as ie: + has_tensorboard = False + logger.warning( + f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" + ) + else: + logger.warning( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run `pip install tensorboard` to enable." + ) + + # 10. Handle the repository creation + if training_args.push_to_hub: + with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f: + git_lfs_extensions = f.read() + if "*.wandb" not in git_lfs_extensions: + f.write("*.wandb filter=lfs diff=lfs merge=lfs -text") + if training_args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) + else: + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) + + # 11. 
Initialize our training + rng = jax.random.PRNGKey(training_args.seed) + rng, dropout_rng = jax.random.split(rng) + + # Store some constants + max_steps = int(training_args.max_steps) + gradient_accumulation_steps = int(training_args.gradient_accumulation_steps) + train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() + batch_size_per_update = train_batch_size * gradient_accumulation_steps + per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) + eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() + to_dtype = to_bf16 if training_args.mixed_precision else to_fp32 + + if training_args.do_train: + num_train_samples = len(vectorized_datasets[data_args.train_split_name]) + steps_per_epoch = num_train_samples // batch_size_per_update + if max_steps > 0: + num_epochs = -(training_args.max_steps // -steps_per_epoch) + total_train_steps = max_steps + else: + num_epochs = int(training_args.num_train_epochs) + total_train_steps = steps_per_epoch * num_epochs + + # Create learning rate schedule + # Create learning rate schedule + linear_decay_lr_schedule_fn = create_learning_rate_fn( + total_train_steps, + training_args.warmup_steps, + training_args.learning_rate, + ) + + # We use Optax's "masking" functionality to not apply weight decay + # to bias and LayerNorm scale parameters. decay_mask_fn returns a + # mask boolean with the same structure as the parameters. + # The mask is True for parameters that should be decayed. + # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart. + # For FlaxT5, one should correct the layer norm parameter naming + # accordingly - see `run_t5_mlm_flax.py` e.g. + def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + layer_norm_params = [ + (name, "scale") + for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"] + ] + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + + if training_args.adafactor: + # Create Adafactor optimizer + optim = optax.adafactor( + learning_rate=linear_decay_lr_schedule_fn, + dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32, + weight_decay_rate=training_args.weight_decay, + weight_decay_mask=decay_mask_fn, + ) + else: + # Create AdamW optimizer + optim = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn, + ) + + # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. 
gradient accumulation steps > 1) + if training_args.multisteps and gradient_accumulation_steps > 1: + optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False) + else: + num_epochs = 0 + total_train_steps = 0 + num_train_samples = 0 + optim = None + + # Setup train state + state = MixedPrecisionTrainState.create( + step=data_args.skip_steps, + apply_fn=model.__call__, + get_attention_mask_fn=model._get_feature_vector_attention_mask, + params=model.params, + tx=optim, + to_dtype=to_dtype, + dropout_rng=dropout_rng, + max_grad_norm=training_args.max_grad_norm, + ) + + # Replicate the train state on each device + state = state.replicate() + blank_id = model.config.pad_token_id + + # Define gradient update step fn + def train_step(state, batch): + # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch + dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + + def compute_loss(params, minibatch): + labels = minibatch.pop("labels") + logits = state.apply_fn( + **minibatch, + params=params, + dropout_rng=dropout_rng, + freeze_feature_encoder=model_args.freeze_feature_encoder, + train=True, + )[0] + logits_mask = state.get_attention_mask_fn(logits.shape[1], batch["attention_mask"]) + loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean") + + return loss + + grad_fn = jax.value_and_grad(compute_loss) + + if gradient_accumulation_steps == 1 or training_args.multisteps: + loss, grad = grad_fn(to_dtype(state.params), batch) + + # Custom gradient accumulation + else: + # add a first dimension over gradient_accumulation_steps for minibatch slices + batch = jax.tree_util.tree_map( + lambda x: x.reshape( + gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::] + ), + batch, + ) + + def accum_minibatch_step(accum_grad, minibatch): + # compute loss, num labels and grad over minibatch and accumulate + loss, grad = grad_fn(to_dtype(state.params), minibatch) + return jax.tree_util.tree_map(jnp.add, accum_grad, grad), loss + + # create an initial state for accumulating losses, num labels and gradients + init_grad = jax.tree_util.tree_map(jnp.zeros_like, to_dtype(state.params)) + # loop accum minibatch step over the number of gradient accumulation steps + grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch) + + # update state + new_state = state.apply_gradients( + grads=grad, + dropout_rng=new_dropout_rng, + to_dtype=to_dtype, + ) + + # compute gradient norms over all layers and globally for detailed monitoring + layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad) + logs = { + "layer_grad_norm": layer_grad_norm, + "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)), + } + + # compute parameter norms over all layers and globally for detailed monitoring + layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params) + logs["layer_param_norm"] = layer_param_norm + logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm)) + + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + metrics.update(logs) + + metrics = jax.lax.pmean(metrics, axis_name="batch") + # metrics = to_fp32(metrics) + + return new_state, metrics + + # Define eval fn + def eval_step(params, batch): + labels = batch.pop("labels") + logits = model(**batch, params=params, train=False)[0] + + logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], 
batch["attention_mask"]) + loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean") + + pred_ids = jnp.argmax(logits, axis=-1) + + # summarize metrics + metrics = {"loss": loss} + metrics = jax.lax.pmean(metrics, axis_name="batch") + # metrics = to_fp32(metrics) + return metrics, pred_ids + + # Create parallel version of the train and eval step + if training_args.do_train: + p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) + + if training_args.do_eval: + p_eval_step = jax.pmap(eval_step, "batch") + + def run_evaluation(step): + if training_args.do_eval: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_preds = [] + eval_labels = [] + + # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length + eval_samples_idx = get_grouped_indices(vectorized_datasets[data_args.eval_split_name], eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) + + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): + samples = [vectorized_datasets[data_args.eval_split_name][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + labels = batch["labels"] + + metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size) + eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))) + eval_metrics.append(metrics) + + eval_labels.extend(labels) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + eval_metrics = to_fp32(eval_metrics) + + # always run compute metrics + error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels) + eval_metrics.update(error_rate_metric) + error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()]) + + # Print metrics and update progress bar + desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})" + epochs.write(desc) + epochs.desc = desc + + # Save metrics + write_wandb_log(eval_metrics, step, prefix="eval") + write_wandb_pred(pred_str, label_str, step) + # if has_tensorboard and jax.process_index() == 0: + # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str) + + def save_checkpoint(step): + # save and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False) + + skip_epochs = data_args.skip_steps // (num_train_samples // batch_size_per_update) + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_train_samples}") + logger.info(f" Num Epochs = {num_epochs}") + logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") + logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}") + logger.info(f" Total train batch size (w. 
parallel & distributed) = {batch_size_per_update}") + logger.info(f" Total optimization steps = {total_train_steps}") + logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}") + logger.info(f" Use scan: {config.use_scan}") + logger.info(f" Fuse matmuls: {config.fuse_matmuls}") + logger.info(f" Skipping: {data_args.skip_steps} steps ({skip_epochs} epochs)") + + train_time = cur_step = 0 + epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) + for epoch in epochs: + if training_args.do_train: + # ======================== Training ================================ + train_start = time.time() + + if epoch < skip_epochs: + logger.info(f"Skipping epoch {epoch + 1}") + # Create sampling rng + rng, input_rng = jax.random.split(rng) + continue + + # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length + train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng) + train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update) + + if data_args.skip_steps > cur_step: + logger.info(f"Skipping {data_args.skip_steps - (epoch * (num_train_samples // batch_size_per_update))} steps...") + # Gather the indices for creating the batch and do a training step + for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1): + cur_step = epoch * (num_train_samples // batch_size_per_update) + step + if cur_step <= data_args.skip_steps: + continue + + samples = [vectorized_datasets[data_args.train_split_name][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + batch = shard(batch.data) + try: + state, train_metric = p_train_step(state, batch) + except TypeError as e: + logger.warning("Encountered following error: \n", e) + + + if cur_step % training_args.logging_steps == 0: + # Save metrics + train_metric = unreplicate(train_metric) + train_time += time.time() - train_start + # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step + write_wandb_log(to_fp32(train_metric), cur_step, prefix=data_args.train_split_name) + # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis) + # if has_tensorboard and jax.process_index() == 0: + # write_train_metric(summary_writer, train_metrics, train_time, cur_step) + + epochs.write( + f"Step... 
({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})" + ) + p_train_step.clear_cache() + + if cur_step % total_train_steps == 0: + break + + if training_args.eval_steps and cur_step % training_args.eval_steps == 0: + run_evaluation(cur_step) + + if cur_step % training_args.save_steps == 0: + save_checkpoint(cur_step) + + if training_args.eval_steps == 0 and (epoch + 1) != num_epochs: + # run evaluation at the end of the epoch if eval steps are not specified + run_evaluation(cur_step) + save_checkpoint(cur_step) + + if training_args.do_train: + save_checkpoint(cur_step) + + cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training + + if training_args.do_eval: + run_evaluation(cur_step) + + # TODO: collapse 'do_predict' into the run_evaluation function + if training_args.do_predict: + for split in [data_args.test_split_name]: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_preds = [] + eval_labels = [] + + # Generate eval set by sequentially sampling indices from the test dataset and grouping by length + eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) + + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)): + samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + labels = batch["labels"] + + metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size) + eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))) + eval_metrics.append(metrics) + + eval_labels.extend(labels) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + eval_metrics = to_fp32(eval_metrics) + + # always run compute metrics + error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels) + eval_metrics.update(error_rate_metric) + error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()]) + + # Print metrics and update progress bar + desc = f"Step... 
({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})" + epochs.write(desc) + epochs.desc = desc + + # Save metrics + write_wandb_log(eval_metrics, cur_step, prefix=split) + write_wandb_pred(pred_str, label_str, cur_step, prefix=split) + # if has_tensorboard and jax.process_index() == 0: + # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str) + + +if __name__ == "__main__": + main() diff --git a/wandb/run-20220810_145446-1k92sv35/files/config.yaml b/wandb/run-20220810_145446-1k92sv35/files/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2ea95760fe00138e4ebc578177ea87ba702cc248 --- /dev/null +++ b/wandb/run-20220810_145446-1k92sv35/files/config.yaml @@ -0,0 +1,33 @@ +wandb_version: 1 + +_wandb: + desc: null + value: + cli_version: 0.12.9 + code_path: code/run_flax_speech_recognition_ctc.py + framework: huggingface + huggingface_version: 4.21.0 + is_jupyter_run: false + is_kaggle_kernel: false + python_version: 3.8.10 + start_time: 1660143286 + t: + 1: + - 1 + - 2 + - 3 + - 11 + - 12 + 2: + - 1 + - 2 + - 3 + - 11 + - 12 + 3: + - 13 + 4: 3.8.10 + 5: 0.12.9 + 6: 4.21.0 + 8: + - 5 diff --git a/wandb/run-20220810_145446-1k92sv35/files/diff.patch b/wandb/run-20220810_145446-1k92sv35/files/diff.patch new file mode 100644 index 0000000000000000000000000000000000000000..13ed34f2b2af595cb1b49926adb1467f57c1469c --- /dev/null +++ b/wandb/run-20220810_145446-1k92sv35/files/diff.patch @@ -0,0 +1,132 @@ +diff --git a/run.recover.sh b/run.recover.sh +index 77ad3fd..6891af1 100755 +--- a/run.recover.sh ++++ b/run.recover.sh +@@ -10,10 +10,9 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c + --num_train_epochs="40" \ + --per_device_train_batch_size="2" \ + --per_device_eval_batch_size="2" \ +- --gradient_accumulation_steps="1" \ +- --precision="full_mixed" \ ++ --gradient_accumulation_steps="2" \ ++ --precision="half_mixed" \ + --matmul_precision="bfloat16" \ +- --multisteps \ + --learning_rate="6.394633237505332e-05" \ + --skip_steps="275000" \ + --warmup_steps="2000" \ +diff --git a/run_flax_speech_recognition_ctc.py b/run_flax_speech_recognition_ctc.py +index a330879..4d0f5fc 100644 +--- a/run_flax_speech_recognition_ctc.py ++++ b/run_flax_speech_recognition_ctc.py +@@ -415,12 +415,12 @@ class MixedPrecisionTrainState(struct.PyTreeNode): + ) + + @classmethod +- def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs): ++ def create(cls, *, apply_fn, params, tx, to_dtype, step=0, **kwargs): + """Creates a new instance with `step=0` and initialized `opt_state`.""" + # downcast optimizer state to bf16 if mixed-precision training + opt_state = tx.init(to_dtype(params)) if tx is not None else None + return cls( +- step=0, ++ step=step, + apply_fn=apply_fn, + params=params, + tx=tx, +@@ -1339,6 +1339,7 @@ def main(): + + # Setup train state + state = MixedPrecisionTrainState.create( ++ step=data_args.skip_steps, + apply_fn=model.__call__, + get_attention_mask_fn=model._get_feature_vector_attention_mask, + params=model.params, +@@ -1520,11 +1521,10 @@ def main(): + + if epoch < skip_epochs: + logger.info(f"Skipping epoch {epoch + 1}") ++ # Create sampling rng ++ rng, input_rng = jax.random.split(rng) + continue + +- # Create sampling rng +- rng, input_rng = jax.random.split(rng) +- + # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length + train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], 
batch_size_per_update, input_rng) + train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update) +@@ -1559,6 +1559,7 @@ def main(): + epochs.write( + f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})" + ) ++ p_train_step.clear_cache() + + if cur_step % total_train_steps == 0: + break +diff --git a/special_tokens_map.json b/special_tokens_map.json +index 218961f..cc1961e 100644 +--- a/special_tokens_map.json ++++ b/special_tokens_map.json +@@ -399,6 +399,34 @@ + "rstrip": false, + "single_word": false + }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, + { + "content": "", + "lstrip": false, +diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log +index 23926ef..aef858a 120000 +--- a/wandb/debug-internal.log ++++ b/wandb/debug-internal.log +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4/logs/debug-internal.log +\ No newline at end of file ++run-20220810_145446-1k92sv35/logs/debug-internal.log +\ No newline at end of file +diff --git a/wandb/debug.log b/wandb/debug.log +index 279853d..0d5686d 120000 +--- a/wandb/debug.log ++++ b/wandb/debug.log +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4/logs/debug.log +\ No newline at end of file ++run-20220810_145446-1k92sv35/logs/debug.log +\ No newline at end of file +diff --git a/wandb/latest-run b/wandb/latest-run +index f069a7a..3128ad6 120000 +--- a/wandb/latest-run ++++ b/wandb/latest-run +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4 +\ No newline at end of file ++run-20220810_145446-1k92sv35 +\ No newline at end of file diff --git a/wandb/run-20220810_145446-1k92sv35/files/output.log b/wandb/run-20220810_145446-1k92sv35/files/output.log new file mode 100644 index 0000000000000000000000000000000000000000..e69a92165b2efe8b5954ac254212e1efc28b514d --- /dev/null +++ b/wandb/run-20220810_145446-1k92sv35/files/output.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c135f3ac8ed3bef82639474fe6693aa305cdd702fc964877622ffc3ae9ce5ce9 +size 224313 diff --git a/wandb/run-20220810_145446-1k92sv35/files/requirements.txt b/wandb/run-20220810_145446-1k92sv35/files/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e0273eb6554b8538eecc3cb9f4a47c988bd3d0dd --- /dev/null +++ b/wandb/run-20220810_145446-1k92sv35/files/requirements.txt @@ -0,0 +1,158 @@ +absl-py==1.0.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +appdirs==1.4.4 +astunparse==1.6.3 +async-timeout==4.0.2 +attrs==21.4.0 +audioread==2.1.9 +backcall==0.2.0 +cachetools==4.2.4 +certifi==2021.10.8 +cffi==1.15.1 +charset-normalizer==2.0.10 +chex==0.1.3 +click==8.0.3 +cloud-tpu-client==0.10 +cloud-tpu-profiler==2.4.0 +clu==0.0.6 +colorama==0.4.5 +commonmark==0.9.1 +configparser==5.2.0 +contextlib2==21.6.0 +cycler==0.11.0 +datasets==2.4.0 +decorator==5.1.0 +dill==0.3.4 +dm-tree==0.1.6 +docker-pycreds==0.4.0 +etils==0.6.0 +exceptiongroup==1.0.0rc8 +filelock==3.4.2 +flatbuffers==2.0 +flax==0.5.3 +fonttools==4.28.5 +frozenlist==1.2.0 +fsspec==2021.11.1 +future==0.18.2 +gast==0.4.0 +gitdb==4.0.9 +gitpython==3.1.26 +google-api-core==1.31.5 
+google-api-python-client==1.8.0 +google-auth-httplib2==0.1.0 +google-auth-oauthlib==0.4.6 +google-auth==2.3.3 +google-pasta==0.2.0 +googleapis-common-protos==1.54.0 +grpcio==1.43.0 +h5py==3.6.0 +httplib2==0.20.2 +huggingface-hub==0.2.1 +hypothesis==6.53.0 +idna==3.3 +importlib-metadata==4.10.0 +importlib-resources==5.4.0 +ipython==7.31.0 +jax==0.3.15 +jaxlib==0.3.15 +jedi==0.18.1 +jiwer==2.3.0 +joblib==1.1.0 +keras-preprocessing==1.1.2 +keras==2.7.0 +kiwisolver==1.3.2 +libclang==12.0.0 +librosa==0.9.2 +libtpu-nightly==0.1.dev20220722 +llvmlite==0.39.0 +markdown==3.3.6 +matplotlib-inline==0.1.3 +matplotlib==3.5.1 +ml-collections==0.1.0 +msgpack==1.0.3 +multidict==5.2.0 +multiprocess==0.70.12.2 +numba==0.56.0 +numpy==1.22.0 +oauth2client==4.1.3 +oauthlib==3.1.1 +opt-einsum==3.3.0 +optax==0.1.3 +packaging==21.3 +pandas==1.3.5 +parso==0.8.3 +pathtools==0.1.2 +pexpect==4.8.0 +pickleshare==0.7.5 +pillow==9.0.0 +pip==22.2.2 +pkg-resources==0.0.0 +pooch==1.6.0 +promise==2.3 +prompt-toolkit==3.0.24 +protobuf==3.19.1 +psutil==5.9.0 +ptyprocess==0.7.0 +pyarrow==6.0.1 +pyasn1-modules==0.2.8 +pyasn1==0.4.8 +pycparser==2.21 +pyctcdecode==0.4.0 +pygments==2.11.1 +pygtrie==2.5.0 +pyparsing==3.0.6 +python-dateutil==2.8.2 +python-levenshtein==0.12.2 +pytz==2021.3 +pyyaml==6.0 +regex==2021.11.10 +requests-oauthlib==1.3.0 +requests==2.27.0 +resampy==0.3.1 +responses==0.18.0 +rich==11.2.0 +rsa==4.8 +sacremoses==0.0.46 +scikit-learn==1.1.1 +scipy==1.7.3 +sentry-sdk==1.5.2 +setuptools==44.0.0 +shortuuid==1.0.8 +six==1.16.0 +smmap==5.0.0 +sortedcontainers==2.4.0 +soundfile==0.10.3.post1 +sox==1.4.1 +subprocess32==3.5.4 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.0 +tensorboard==2.7.0 +tensorflow-cpu==2.7.0 +tensorflow-datasets==4.4.0 +tensorflow-estimator==2.7.0 +tensorflow-io-gcs-filesystem==0.23.1 +tensorflow-metadata==1.5.0 +tensorflow==2.7.0 +tensorstore==0.1.21 +termcolor==1.1.0 +threadpoolctl==3.1.0 +tokenizers==0.11.2 +toolz==0.11.2 +torch==1.12.0 +torchaudio==0.12.0+cpu +tqdm==4.62.3 +traitlets==5.1.1 +transformers==4.21.0 +typing-extensions==4.3.0 +uritemplate==3.0.1 +urllib3==1.26.7 +wandb==0.12.9 +wcwidth==0.2.5 +werkzeug==2.0.2 +wheel==0.37.1 +wrapt==1.13.3 +xxhash==2.0.2 +yarl==1.7.2 +yaspin==2.1.0 +zipp==3.7.0 \ No newline at end of file diff --git a/wandb/run-20220810_145446-1k92sv35/files/wandb-metadata.json b/wandb/run-20220810_145446-1k92sv35/files/wandb-metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..78cc0f94dc9c6ecd54722eecf08be79cfe3fa9da --- /dev/null +++ b/wandb/run-20220810_145446-1k92sv35/files/wandb-metadata.json @@ -0,0 +1,69 @@ +{ + "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29", + "python": "3.8.10", + "heartbeatAt": "2022-08-10T14:54:50.575340", + "startedAt": "2022-08-10T14:54:46.729335", + "docker": null, + "cpu_count": 96, + "cuda": null, + "args": [ + "--model_name_or_path=./", + "--hub_model_id=NbAiLab/wav2vec2-1b-npsc-nst-tpu", + "--tokenizer_name=./", + "--output_dir=./", + "--overwrite_output_dir", + "--num_train_epochs=40", + "--per_device_train_batch_size=2", + "--per_device_eval_batch_size=2", + "--gradient_accumulation_steps=2", + "--precision=half_mixed", + "--matmul_precision=bfloat16", + "--learning_rate=6.394633237505332e-05", + "--skip_steps=275000", + "--warmup_steps=2000", + "--length_column_name=input_length", + "--evaluation_strategy=steps", + "--text_column_name=text", + "--save_steps=5000", + "--eval_steps=5000", + "--logging_steps=100", + "--layerdrop=0.041", + "--attention_dropout=0.094", + 
"--activation_dropout=0.055", + "--hidden_dropout=0.047", + "--save_total_limit=5", + "--freeze_feature_encoder", + "--feat_proj_dropout=0.04", + "--mask_time_prob=0.082", + "--mask_time_length=10", + "--mask_feature_prob=0.25", + "--mask_feature_length=64", + "--gradient_checkpointing", + "--min_duration_in_seconds=0.5", + "--max_duration_in_seconds=30.0", + "--use_auth_token", + "--seed=42", + "--group_by_length", + "--do_train", + "--do_eval", + "--push_to_hub", + "--preprocessing_num_workers=32", + "--ctc_zero_infinity", + "--do_lower_case", + "--wandb_project=wav2vec2", + "--wandb_name=wav2vec2-1b-npsc-nst-tpu (cont.)", + "--remove_punctuation" + ], + "state": "running", + "program": "run_flax_speech_recognition_ctc.py", + "codePath": "run_flax_speech_recognition_ctc.py", + "git": { + "remote": "https://huggingface.co/NbAiLab/wav2vec2-1b-npsc-nst-tpu", + "commit": "f624ac4bfedfbf56891676d7c5f2e37b4c8e0745" + }, + "email": "versae@gmail.com", + "root": "/data/wav2vec2-1b-npsc-nst-tpu", + "host": "t1v-n-eedfb410-w-0", + "username": "javierr", + "executable": "/data/flax/bin/python" +} diff --git a/wandb/run-20220810_145446-1k92sv35/files/wandb-summary.json b/wandb/run-20220810_145446-1k92sv35/files/wandb-summary.json new file mode 100644 index 0000000000000000000000000000000000000000..5f7a19ba7c87aa5e39a0ed8ea16b1e0a8be103e5 --- /dev/null +++ b/wandb/run-20220810_145446-1k92sv35/files/wandb-summary.json @@ -0,0 +1 @@ +{"_wandb": {"runtime": 727}} \ No newline at end of file diff --git a/wandb/run-20220810_145446-1k92sv35/logs/debug-internal.log b/wandb/run-20220810_145446-1k92sv35/logs/debug-internal.log new file mode 100644 index 0000000000000000000000000000000000000000..2b3c0c66138bc3a2e8a7e17d54289a371cadfb7d --- /dev/null +++ b/wandb/run-20220810_145446-1k92sv35/logs/debug-internal.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d064b2b531d3a5823be66688cf52c5cc45e8f453efd03f83df0148c4827f85db +size 43560 diff --git a/wandb/run-20220810_145446-1k92sv35/logs/debug.log b/wandb/run-20220810_145446-1k92sv35/logs/debug.log new file mode 100644 index 0000000000000000000000000000000000000000..1c40c260fd66a16523977bf5727e411cdbf8a073 --- /dev/null +++ b/wandb/run-20220810_145446-1k92sv35/logs/debug.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f46d292d19f5c77a9812e097b2c3160a150959386c7d3ccc7915aca3eb061632 +size 6071 diff --git a/wandb/run-20220810_145446-1k92sv35/run-1k92sv35.wandb b/wandb/run-20220810_145446-1k92sv35/run-1k92sv35.wandb new file mode 100644 index 0000000000000000000000000000000000000000..34b6d639b2657fa7949e794b754cacdd04c72ce9 --- /dev/null +++ b/wandb/run-20220810_145446-1k92sv35/run-1k92sv35.wandb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b8052ef4aa82bb8b1afd0f2976795c9951b74f6c9bddf669d66afdd3bafdb85 +size 238991 diff --git a/wandb/run-20220810_151736-2jo5la5b/files/code/run_flax_speech_recognition_ctc.py b/wandb/run-20220810_151736-2jo5la5b/files/code/run_flax_speech_recognition_ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..4d0f5fc813b76949369e0fe19e8facee530f724a --- /dev/null +++ b/wandb/run-20220810_151736-2jo5la5b/files/code/run_flax_speech_recognition_ctc.py @@ -0,0 +1,1632 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2022 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition. +""" +# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. + +import logging +import math +import os +import re +import sys +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + +import datasets +import numpy as np +from datasets import DatasetDict, load_dataset, load_metric +from tqdm import tqdm + +import flax +import jax +import jax.numpy as jnp +import optax +import transformers +import wandb as wandb +from flax import core, jax_utils, struct, traverse_util +from flax.jax_utils import unreplicate, pad_shard_unpad +from flax.training.common_utils import get_metrics, shard, shard_prng_key +from huggingface_hub import Repository +from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC +from optax._src import linear_algebra +from transformers import ( + AutoFeatureExtractor, + AutoProcessor, + AutoTokenizer, + HfArgumentParser, + TrainingArguments, + is_tensorboard_available, + set_seed, +) +from transformers.file_utils import get_full_repo_name +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.17.0.dev0") + +require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") + +logger = logging.getLogger(__name__) + + +@flax.struct.dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + feature_extractor_name: Optional[str] = field( + default=None, metadata={"help": "feature extractor name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." 
+ }, + ) + freeze_feature_encoder: bool = field( + default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."} + ) + attention_dropout: float = field( + default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."} + ) + activation_dropout: float = field( + default=0.1, + metadata={ + "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler." + }, + ) + hidden_dropout: float = field( + default=0.1, + metadata={ + "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler." + }, + ) + feat_proj_dropout: float = field( + default=0.0, + metadata={ + "help": "The feat proj dropout probability for feature encoder representations." + }, + ) + final_dropout: float = field( + default=0.0, + metadata={"help": "The dropout probability for the final projection layer."}, + ) + mask_time_prob: float = field( + default=0.1, + metadata={ + "help": "The spec aug dropout probability for feature encoder representations." + }, + ) + mask_time_length: int = field( + default=10, + metadata={"help": "Length of vector span to mask along the time axis."}, + ) + mask_feature_prob: float = field( + default=0.0, + metadata={ + "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector" + "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis." + }, + ) + mask_feature_length: int = field( + default=10, + metadata={"help": "Length of vector span to mask along the feature axis."}, + ) + layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."}) + ctc_loss_reduction: Optional[str] = field( + default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."} + ) + ctc_zero_infinity: Optional[bool] = field( + default=False, metadata={"help": "If True, will try yo aboud the CTC loss goinf to infinity."} + ) + + +@flax.struct.dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: str = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + text_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, + ) + dataset_cache_dir: Optional[str] = field( + default=None, metadata={"help": "Path to cache directory for saving and loading datasets"} + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." 
+ }, + ) + max_test_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of test examples to this " + "value if set." + }, + ) + audio_column_name: str = field( + default="audio", + metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"}, + ) + text_column_name: str = field( + default="text", + metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"}, + ) + max_duration_in_seconds: float = field( + default=20.0, + metadata={ + "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`" + }, + ) + min_duration_in_seconds: float = field( + default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"} + ) + max_label_length: Optional[int] = field( + default=512, + metadata={ + "help": "The minimum total sequence length for target text after tokenization. Sequences shorter " + "than this will be filtered." + }, + ) + min_label_length: Optional[int] = field( + default=2, + metadata={ + "help": "The minimum total sequence length for target text after tokenization. Sequences shorter " + "than this will be filtered." + }, + ) + pad_input_to_multiple_of: Optional[int] = field( + default=32000, + metadata={ + "help": "If set will pad the input sequence to a multiple of the provided value. " + "This is important to avoid triggering recompilations on TPU." + }, + ) + pad_target_to_multiple_of: Optional[int] = field( + default=None, + metadata={ + "help": "If set will pad the target sequence to a multiple of the provided value. " + "This is important to avoid triggering recompilations on TPU." + }, + ) + preprocessing_only: bool = field( + default=False, + metadata={ + "help": "Whether to only do data preprocessing and skip training. " + "This is especially useful when data preprocessing errors out in distributed training due to timeout. " + "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` " + "so that the cached datasets can consequently be loaded in distributed training" + }, + ) + train_split_name: str = field( + default="train", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + eval_split_name: str = field( + default="validation", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + do_lower_case: bool = field( + default=True, + metadata={"help": "Whether the target text should be lower cased."}, + ) + wandb_project: str = field( + default="flax-speech-recognition-ctc", + metadata={"help": "The name of the wandb project."}, + ) + wandb_name: str = field( + default=None, + metadata={"help": "The name of the wandb run."}, + ) + wandb_job_type: str = field( + default="CTC", + metadata={"help": "The name of the wandb job type."}, + ) + test_split_name: str = field( + default="test", + metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"}, + ) + remove_punctuation: bool = field( + default=False, metadata={"help": "Whether or not to remove punctuation during training."} + ) + skip_steps: Optional[int] = field( + default=0, + metadata={ + "help": "Skip this number of steps. 
Useful to continue training" + }, + ) + + +# @flax.struct.dataclass +@dataclass +class FlaxTrainingArguments(TrainingArguments): + precision: str = field( + default="full", + metadata={ + "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision" + "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**" + }, + ) + matmul_precision: str = field( + default="default", + metadata={ + "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. " + "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). " + "This configuration option does not change the behaviours of such calls with explicit precision arguments; " + "it only changes the behaviors of calls with no such argument provided. " + "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`." + }, + ) + multisteps: bool = field( + default=False, + metadata={ + "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, " + "a custom gradient accumulation implementation will be employed." + }, + ) + + +def to_fp32(t): + return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t) + + +def to_bf16(t): + return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t) + + +class MixedPrecisionTrainState(struct.PyTreeNode): + """Train state for use with a single Optax optimizer. + Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py + + Synopsis:: + + state = TrainState.create( + apply_fn=model.apply, + params=variables['params'], + tx=tx) + grad_fn = jax.grad(make_loss_fn(state.apply_fn)) + for batch in data: + grads = grad_fn(state.params, batch) + state = state.apply_gradients(grads=grads) + + Args: + step: Counter starts at 0 and is incremented by every call to + `.apply_gradients()`. + apply_fn: Usually set to `model.apply()`. Kept in this dataclass for + convenience to have a shorter params list for the `train_step()` function + in your training loop. + params: The parameters to be updated by `tx` and used by `apply_fn`. + tx: An Optax gradient transformation. + opt_state: The state for `tx`. + dropout_rng: PRNG key for stochastic operations. + bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. + """ + + step: int + apply_fn: Callable = struct.field(pytree_node=False) + get_attention_mask_fn: Callable = struct.field(pytree_node=False) + params: core.FrozenDict[str, Any] + tx: optax.GradientTransformation = struct.field(pytree_node=False) + opt_state: optax.OptState + dropout_rng: jnp.ndarray + max_grad_norm: Optional[float] = 1.0 + + def apply_gradients(self, *, grads, to_dtype, **kwargs): + """Updates `step`, `params`, `opt_state` and `**kwargs` in return value. + + Note that internally this function calls `.tx.update()` followed by a call + to `optax.apply_updates()` to update `params` and `opt_state`. + + Args: + grads: Gradients that have the same pytree structure as `.params`. + **kwargs: Additional dataclass attributes that should be `.replace()`-ed. 
+ + Returns: + An updated instance of `self` with `step` incremented by one, `params` + and `opt_state` updated by applying `grads`, and additional attributes + replaced as specified by `kwargs`. + """ + + # clip gradients by global l2 norm + casted_max_grad_norm = to_dtype(self.max_grad_norm) + g_norm = linear_algebra.global_norm(grads) + g_norm = jnp.maximum(casted_max_grad_norm, g_norm) + grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads) + + # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training + # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is) + updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params) + + new_params = optax.apply_updates(self.params, updates) + return self.replace( + step=self.step + 1, + params=new_params, + opt_state=to_dtype(new_opt_state), + **kwargs, + ) + + @classmethod + def create(cls, *, apply_fn, params, tx, to_dtype, step=0, **kwargs): + """Creates a new instance with `step=0` and initialized `opt_state`.""" + # downcast optimizer state to bf16 if mixed-precision training + opt_state = tx.init(to_dtype(params)) if tx is not None else None + return cls( + step=step, + apply_fn=apply_fn, + params=params, + tx=tx, + opt_state=opt_state, + **kwargs, + ) + + def replicate(self): + return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) + + +@flax.struct.dataclass +class FlaxDataCollatorSpeechSeq2SeqWithPadding: + """ + Data collator that will dynamically pad the inputs received. + Args: + processor ([`Wav2Vec2Processor`]) + The processor used for proccessing the data. + decoder_start_token_id (:obj: `int`) + The begin-of-sentence of the decoder. + input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned input sequences (according to the model's padding side and padding index) + among: + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned target sequences (according to the model's padding side and padding index). + See above for details. + max_input_length (:obj:`float`, `optional`): + Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). + pad_input_to_multiple_of (:obj:`int`, `optional`): + If set will pad the input sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + pad_target_to_multiple_of (:obj:`int`, `optional`): + If set will pad the target sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). 
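        For instance (illustrative numbers only, using this script's default of
        ``pad_input_to_multiple_of=32000``): a 49_601-sample utterance is padded up to
        64_000 samples, so batches fall into a small, fixed set of input shapes and XLA
        does not have to recompile for every new utterance length on TPU.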
+ """ + + processor: Any + input_padding: Union[bool, str] = "longest" + label_padding: Union[bool, str] = "max_length" + pad_input_to_multiple_of: Optional[int] = None + pad_to_multiple_of_label: Optional[int] = None + max_input_length: Optional[float] = None + max_label_length: Optional[float] = None + + def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]: + # split inputs and labels since they have to be of different lengths and need + # different padding methods + input_features = [{"input_values": feature["input_values"]} for feature in features] + label_features = [{"input_ids": feature["labels"]} for feature in features] + + # reformat list to dict and set to pytorch format + batch = self.processor.feature_extractor.pad( + input_features, + max_length=self.max_input_length, + padding=self.input_padding, + pad_to_multiple_of=self.pad_input_to_multiple_of, + return_tensors="np", + ) + + labels_batch = self.processor.tokenizer.pad( + label_features, + max_length=self.max_label_length, + padding=self.label_padding, + pad_to_multiple_of=self.pad_to_multiple_of_label, + return_tensors="np", + ) + + labels = labels_batch["input_ids"] + labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1)) + labels = labels.filled(fill_value=-100) + + batch["labels"] = labels + + return batch + + +def get_grouped_indices( + dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None +) -> np.array: + """ + Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486) + Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar + lengths. To do this, the indices are: + + - randomly permuted (if a JAX rng is specified) + - grouped in mega-batches of size `mega_batch_mult * batch_size` + - sorted by length in each mega-batch + + The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of + maximum length placed first, so that an OOM happens sooner rather than later. + """ + lengths = dataset["input_length"] + + # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller. + if mega_batch_mult is None: + mega_batch_mult = min(len(lengths) // (batch_size * 4), 50) + # Just in case, for tiny datasets + if mega_batch_mult == 0: + mega_batch_mult = 1 + + # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler. + num_samples = len(lengths) + indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples) + + megabatch_size = mega_batch_mult * batch_size + megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] + megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches] + + # The rest is to get the biggest batch first. 
+ # Since each megabatch is sorted by descending length, the longest element is the first + megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches] + max_idx = np.argmax(megabatch_maximums).item() + # Switch to put the longest batch in first position + # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch) + megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0] + + megabatches = np.array([i for megabatch in megabatches for i in megabatch]) + + return megabatches + + +def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray: + """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by + the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned.""" + num_samples = len(samples_idx) + if drop_last: + samples_to_remove = num_samples % batch_size + if samples_to_remove != 0: + samples_idx = samples_idx[:-samples_to_remove] + sections_split = num_samples // batch_size + samples_idx = samples_idx.reshape((sections_split, batch_size)) + else: + sections_split = math.ceil(num_samples / batch_size) + samples_idx = np.array_split(samples_idx, sections_split) + return samples_idx + + +def write_train_metric(summary_writer, train_metrics, train_time, step): + summary_writer.scalar("train_time", train_time, step) + + train_metrics = get_metrics(train_metrics) + for key, vals in train_metrics.items(): + tag = f"train_{key}" + for i, val in enumerate(vals): + summary_writer.scalar(tag, val, step - len(vals) + i + 1) + + +def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None): + for metric_name, value in eval_metrics.items(): + summary_writer.scalar(f"eval_{metric_name}", value, step) + + if pred_str is not None: + # write output actual predictions for debugging + summary_writer.text("eval_predictions", "\n".join(pred_str), step) + + +def write_wandb_log(metrics, step, prefix=None): + if jax.process_index() == 0: + log_metrics = {} + for k, v in metrics.items(): + if "layer" in k: + log_metrics[f"{k}/"] = v + elif prefix is not None: + log_metrics[f"{prefix}/{k}"] = v + else: + log_metrics[k] = v + wandb.log(log_metrics, step) + + +def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"): + if jax.process_index() == 0: + # convert str data to a wandb compatible format + str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))] + # we'll log the first 50 predictions for each epoch + wandb.log( + { + f"{prefix}/step_{int(step / 1000)}k": wandb.Table( + columns=["label_str", "pred_str"], data=str_data[:num_log] + ) + }, + step, + ) + + +def create_learning_rate_fn( + num_train_steps: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.array]: + """Returns a linear warmup, linear_decay learning rate function.""" + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +def ctc_loss( + logits, + logits_attention_mask, + labels, + blank_id, + loss_reduction="mean", + output_emission_dict=False, + log_epsilon=-100000.0, +): + """Computes CTC loss. 
+ This function performs forward computation over an FSA with `N * 2` states + where `N` is the max number of labels. The states are split into two groups: + Phi states and emission states. a phi-state accepts repetition of + phi (blank)-symbols and transits to emission state when the correct label is + observed. An emission state accepts repetition of the label and transits to + the next phi states at any time (so called epsilon-transition). + Below, `B` denotes the batch size, `T` denotes the time steps in `logits`, + and `N` denotes the time steps in `labels`. + Args: + logits: (B, T, K)-array containing log-probabilities of each class. + logitpaddings: (B, T)-array. Padding indicators for `logits`. + labels: (B, N)-array containing reference integer labels. + labelpaddings: (B, N)-array. Padding indicators for `labels`. Currently, + `labels` must be right-padded, i.e. each row of `labelpaddings` must be + repetition of zeroes, followed by repetition of ones. + blank_id: Id for blank token. + loss_reduction: one of "mean", "sum", "default" + - "none": no reduction is applied. + - "mean": output loss will be divided by target lengths and then the + mean over the batch is taken. + - "sum": output loss are summed over batch + output_emission_dict: whether to output additional information about the emission probs + Returns: + A pair of `(per_seq_loss, aux)`. + per_seq_loss: + (B,)-array containing loss values for each sequence in the batch. + aux: Dictionary containing interim variables used for computing losses. + aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each + phi-state corresponding to the n-th label. + aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each + emission-state corresponding to the n-th label. + aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol + corresponding to each time frame. + aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label + corresponding to each time frame. + """ + # label paddings are indicated by -100 + labelpaddings = labels < 0 + # logit paddings are the inverse of attention_mask + logitpaddings = ~logits_attention_mask + + # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py + batchsize, unused_maxinputlen, num_classes = logits.shape + batchsize_, maxlabellen = labels.shape + + logprobs = jax.nn.log_softmax(logits) + labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32) + + # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1]. 
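    # (Hypothetical values for illustration: labels = [[3, 3, 5]] yields
    #  repeat = [[1., 0.]] from the comparison below and [[1., 0., 0.]] after the
    #  zero-padding on the right, marking only the "3 -> 3" repetition.)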
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32) + repeat = jnp.pad(repeat, ((0, 0), (0, 1))) + + logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1] + logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] + + one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K] + logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot) + logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] + + logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N] + logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0) + logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N] + + def loop_body(prev, x): + prev_phi, prev_emit = prev + # emit-to-phi epsilon transition, except if the next label is repetition + prev_phi_orig = prev_phi + prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat)) + + logprob_emit, logprob_phi, pad = x + + # phi-to-emit transition + next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit) + # self-loop transition + next_phi = prev_phi + logprob_phi + # emit-to-phi blank transition only when the next label is repetition + next_phi = next_phi.at[:, 1:].set( + jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)) + ) + + pad = pad.reshape((batchsize, 1)) + next_emit = pad * prev_emit + (1.0 - pad) * next_emit + next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi + + return (next_phi, next_emit), (next_phi, next_emit) + + xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0))) + _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs) + + # last row needs to be updated with the last epsilon transition + logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1])) + logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last) + + # extract per_seq_loss + one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1] + per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot) + + if loss_reduction == "mean": + target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1) + loss = (per_seq_loss / target_lengths).mean() + elif loss_reduction == "sum": + loss = per_seq_loss.sum() + else: + loss = per_seq_loss + + if not output_emission_dict: + return loss + + return loss, { + "logalpha_phi": logalpha_phi, + "logalpha_emit": logalpha_emit, + "logprobs_phi": logprobs_phi, + "logprobs_emit": logprobs_emit, + } + + +def make_dataset(data_args, seed=42): + # Pre-processing dataset + import re + + def map_nst(entry): + text = entry["text"].lower() + text = text.replace("(...vær stille under dette opptaket...)", "") + text = re.sub('[áàâ]', 'a', text) + text = re.sub('[ä]', 'æ', text) + text = re.sub('[éèëê]', 'e', text) + text = re.sub('[íìïî]', 'i', text) + text = re.sub('[óòöô]', 'o', text) + text = re.sub('[ö]', 'ø', text) + text = re.sub('[ç]', 'c', text) + text = re.sub('[úùüû]', 'u', text) + # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text) + text = re.sub('\s+', ' ', text) + return {"text": text} + + def filter_nst(entry): + if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)): + return False # Too short + if re.match(entry["type"], "pIW|CA"): + return False # Spelling out words + return True + + def filter_npsc(entry): + # False if there are digits in the 
text + if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)): + return False # Too short + if re.search("\d", entry["text"]): + return False + return True + + def map_npsc(entry): + batch = {"text": entry["text"].lower()} + batch["text"] = re.sub('[áàâ]', 'a', batch["text"]) + batch["text"] = re.sub('[ä]', 'æ', batch["text"]) + batch["text"] = re.sub('[éèëê]', 'e', batch["text"]) + batch["text"] = re.sub('[íìïî]', 'i', batch["text"]) + batch["text"] = re.sub('[óòöô]', 'o', batch["text"]) + batch["text"] = re.sub('[ö]', 'ø', batch["text"]) + batch["text"] = re.sub('[ç]', 'c', batch["text"]) + batch["text"] = re.sub('[úùüû]', 'u', batch["text"]) + batch["text"] = re.sub('\s', ' ', batch["text"]) + batch["text"] = re.sub('', 'eee', batch["text"]) + batch["text"] = re.sub('', 'qqq', batch["text"]) + batch["text"] = re.sub('', 'mmm', batch["text"]) + batch["text"] = re.sub('', 'xxx', batch["text"]) + # batch["text"] = re.sub('', '?', batch["text"]) + if "<" in batch["text"]: + raise ValueError(batch["text"]) + return batch + + nst = datasets.load_dataset("NbAiLab/NST", "no-close") + npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3") + # TODO NST_hesitate + + split = len(npsc[data_args.train_split_name]) / (len(npsc[data_args.train_split_name]) + len(npsc[data_args.eval_split_name])) # Use same train/val ratio as NPSC + nst_train = nst[data_args.train_split_name].train_test_split(train_size=split, seed=seed) + nst[data_args.train_split_name] = nst_train["train"] + nst[data_args.eval_split_name] = nst_train["test"] + + nst = nst.filter(filter_nst).map( + map_nst, + num_proc=data_args.preprocessing_num_workers, + desc="filtering NST", + ).shuffle(seed=seed) + npsc = npsc.filter(filter_npsc).map( + map_npsc, + num_proc=data_args.preprocessing_num_workers, + desc="filtering NPSC", + ).shuffle(seed=seed) + + npsc_base = npsc.remove_columns([col for col in npsc[data_args.train_split_name].column_names if col not in ["text", "audio"]]) + nst_base = nst.remove_columns([col for col in nst[data_args.train_split_name].column_names if col not in ["text", "audio"]]) + + combined = {} + for split in data_args.train_split_name, data_args.eval_split_name, data_args.test_split_name: + probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples + probs = (probs / probs.sum()).tolist() + comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed) + combined[split] = comb + + return datasets.DatasetDict(**combined) + +def main(): + # 1. Parse input arguments + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # 2. Setup logging + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + # Set the verbosity to info of the Transformers logger. + # We only want one process per machine to log things on the screen. + logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) + if jax.process_index() == 0: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # Set up wandb run + if jax.process_index() == 0: + wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type) + + logger.info("Training/evaluation parameters %s", training_args) + + # Set the default TPU matmul precision and display the number of devices + jax.config.update("jax_default_matmul_precision", training_args.matmul_precision) + logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}") + + # 4. Load dataset + + set_seed(training_args.seed) + raw_datasets = make_dataset(data_args, seed=training_args.seed) + + # raw_datasets = DatasetDict() + + # if training_args.do_train: + # raw_datasets[data_args.train_split_name] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=data_args.train_split_name, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + # if training_args.do_eval: + # raw_datasets[data_args.eval_split_name] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=data_args.eval_split_name, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + # if training_args.do_predict: + # test_split = data_args.test_split_name.split("+") + # for split in test_split: + # raw_datasets[split] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=split, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + if not training_args.do_train and not training_args.do_eval and not training_args.do_predict: + raise ValueError( + "Cannot not train, not do evaluation and not do prediction. At least one of " + "training, evaluation or prediction has to be done." + ) + + # if not training, there is no need to run multiple epochs + if not training_args.do_train: + training_args.num_train_epochs = 1 + + if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--audio_column_name` to the correct audio column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--text_column_name` to the correct text column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + # 5. 
Load pretrained model, tokenizer, and feature extractor + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + config = Wav2Vec2Config.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + # update config according to training args, model args, and tokenizer attributes + config.update( + { + "feat_proj_dropout": model_args.feat_proj_dropout, + "attention_dropout": model_args.attention_dropout, + "hidden_dropout": model_args.hidden_dropout, + "final_dropout": model_args.final_dropout, + "mask_time_prob": model_args.mask_time_prob, + "mask_time_length": model_args.mask_time_length, + "mask_feature_prob": model_args.mask_feature_prob, + "mask_feature_length": model_args.mask_feature_length, + "gradient_checkpointing": training_args.gradient_checkpointing, + "layerdrop": model_args.layerdrop, + "ctc_loss_reduction": model_args.ctc_loss_reduction, + "ctc_zero_infinity": model_args.ctc_zero_infinity, + "pad_token_id": tokenizer.pad_token_id, + "vocab_size": tokenizer.vocab_size, # len(tokenizer), + "activation_dropout": model_args.activation_dropout, + } + ) + + if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr": + raise ValueError( + "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to " + "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus," + "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is mostly likely " + "built on an lowercased corpus. In this case, set `tokenizer.do_lower_case` to ``False`." + ) + + if training_args.precision == "full_mixed": + dtype = jnp.bfloat16 + training_args.mixed_precision = True + elif training_args.precision == "half_mixed": + dtype = jnp.bfloat16 + training_args.mixed_precision = False + else: + dtype = jnp.float32 + training_args.mixed_precision = False + + try: + model = FlaxWav2Vec2ForCTC.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + except: + model = FlaxWav2Vec2ForCTC.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + from_pt=True, + ) + + # 6. Resample speech dataset ALWAYS + raw_datasets = raw_datasets.cast_column( + data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) + ) + + # 7. Preprocessing the datasets. 
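    # (The duration limits below are converted to sample counts; illustrative arithmetic,
    #  assuming the usual 16 kHz Wav2Vec2 feature extractor: --max_duration_in_seconds=30.0
    #  gives max_input_length = 480_000 samples and --min_duration_in_seconds=0.5 gives
    #  min_input_length = 8_000 samples.)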
+ # We need to read the audio files as arrays and tokenize the targets. + max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate) + min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate) + max_target_length = data_args.max_label_length + min_target_length = data_args.min_label_length + pad_input_to_multiple_of = data_args.pad_input_to_multiple_of + audio_column_name = data_args.audio_column_name + num_workers = data_args.preprocessing_num_workers + text_column_name = data_args.text_column_name + model_input_name = feature_extractor.model_input_names[0] + do_lower_case = data_args.do_lower_case + dataset_name = data_args.dataset_name + chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ") + chars_to_ignore_regex = f'[{"".join(chars_to_ignore)}]' + # gigaspeech_punctuation = {" ": ",", " ": ".", " ": "?", " ": "!"} + # gigaspeech_disfluencies = ["", ""] + # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "", "", "", "[laughter-", + # "[vocalized-noise]", "_1"] + # swb_punctuations = ["{", "}", "[", "]-", "]"] + # earnings_disfluencies = ["", "", "", "inaudible", "", ""] + ignore_segments = ["ignore_time_segment_in_scoring", "", "", "[noise]", "[laughter]", "[silence]", + "[vocalized-noise]", "", "", "", "", "", "", ""] + + if training_args.do_train and data_args.max_train_samples is not None: + raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].select(range(data_args.max_train_samples)) + + if training_args.do_eval and data_args.max_eval_samples is not None: + raw_datasets[data_args.eval_split_name] = raw_datasets[data_args.eval_split_name].select(range(data_args.max_eval_samples)) + + if training_args.do_predict and data_args.max_test_samples is not None: + raw_datasets[data_args.test_split_name] = raw_datasets[data_args.test_split_name].select(range(data_args.max_eval_samples)) + + if training_args.do_train and data_args.remove_punctuation: + + def remove_punctuation(batch): + batch[text_column_name] = ( + re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "") + ) + + raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].map( + remove_punctuation, + num_proc=data_args.preprocessing_num_workers, + desc="removing punctuation from train split", + ) + + # filter data where the targets are ignored in scoring + def is_target_labels(input_str): + return input_str.lower() not in ignore_segments + + raw_datasets = raw_datasets.filter( + is_target_labels, + num_proc=num_workers, + input_columns=[text_column_name], + desc="filtering data where the targets are ignored in scoring", + ) + + def prepare_dataset(batch): + # process audio + try: + sample = batch[audio_column_name] + except ValueError: + sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate} + inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"]) + # process audio length + batch[model_input_name] = inputs.input_values[0] + batch["input_length"] = len(batch["input_values"]) + + # process targets + input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name] + + # if dataset_name == "google/xtreme_s": + # # Finally, we tokenize the processed text + # batch["labels"] = tokenizer(input_str).input_ids + # batch["labels_length"] = len(batch["labels"]) + # return batch + + # # Common Voice 9 + # if input_str.startswith('"') and input_str.endswith('"'): + # # we can remove 
trailing quotation marks as they do not affect the transcription + # input_str = input_str[1:-1] + # # normalize quotation marks + # input_str = re.sub(r'["“”]', '"', input_str) + # # normalize apostrophes + # input_str = re.sub(r"[’']", "'", input_str) + # # normalize hyphens + # input_str = re.sub(r"[—–]", "-", input_str) + # # replace double quotation marks with single + # input_str = input_str.replace('""', '"') + # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str): + # # for CV9, we'll normalize the text to always finish with punctuation + # if input_str[-1] not in [".", "?", "!"]: + # input_str = input_str + "." + + # # TEDLIUM-3 + # # delete the token from the text and replace spaced apostrophes with un-spaced + # input_str = input_str.replace("", "").replace(" '", "'") + + # # GigaSpeech + # for disfluency in gigaspeech_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # convert spelled out punctuation to symbolic form + # for punctuation, replacement in gigaspeech_punctuation.items(): + # input_str = input_str.replace(punctuation, replacement) + # if dataset_name == "speechcolab/gigaspeech" and len(input_str): + # # for GS, we'll normalize the text to always finish with punctuation + # if input_str[-1] not in [".", "?", "!"]: + # input_str = input_str + "." + + # # SWB + # for disfluency in swb_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # remove parenthesised text (test data only) + # input_str = re.sub("[\(].*?[\)]", "", input_str) + # for punctuation in swb_punctuations: + # input_str = input_str.replace(punctuation, "") + # # replace anomalous words with their correct transcriptions + # split_str = input_str.split("/") + # if len(split_str) > 1: + # input_str = " ".join( + # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]]) + + # # Earnings 22 + # for disfluency in earnings_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # replace mal-formatted ellipsis + # input_str = input_str.replace("…", ".") + + # JIWER compliance + # remove multiple spaces + input_str = re.sub(r"\s\s+", " ", input_str) + # strip trailing spaces + input_str = input_str.strip() + + # Finally, we tokenize the processed text + batch["labels"] = tokenizer(input_str).input_ids + batch["labels_length"] = len(batch["labels"]) + return batch + + vectorized_datasets = raw_datasets.map( + prepare_dataset, + remove_columns=next(iter(raw_datasets.values())).column_names, + num_proc=num_workers, + desc="preprocess dataset", + ) + + # filter data with inputs shorter than min_input_length or longer than max_input_length + def is_audio_in_length_range(length): + return length > min_input_length and length < max_input_length + + vectorized_datasets = vectorized_datasets.filter( + is_audio_in_length_range, + num_proc=num_workers, + input_columns=["input_length"], + ) + + # filter data with targets shorter than min_target_length or longer than max_target_length + def is_labels_in_length_range(length): + return length > min_target_length # and length < max_target_length + + vectorized_datasets = vectorized_datasets.filter( + is_labels_in_length_range, + num_proc=num_workers, + input_columns=["labels_length"], + ) + + # for large datasets it is advised to run the preprocessing on a + # single machine first with `args.preprocessing_only` since there will mostly likely + # be a timeout when running the script in distributed mode. 
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the + # cached dataset + if data_args.preprocessing_only: + cache = {k: v.cache_files for k, v in vectorized_datasets.items()} + logger.info(f"Data preprocessing finished. Files cached at {cache}.") + return + + # 8. Load Metrics + wer_metric = load_metric("wer") + cer_metric = load_metric("cer") + + def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]): + padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids)) + + pred_str = tokenizer.batch_decode(pred_ids) + # we do not want to group tokens when computing the metrics + label_str = tokenizer.batch_decode(padded_ids, group_tokens=False) + + wer = wer_metric.compute(predictions=pred_str, references=label_str) + cer = cer_metric.compute(predictions=pred_str, references=label_str) + + return {"wer": wer, "cer": cer}, pred_str, label_str + + # 9. save feature extractor, tokenizer and config + feature_extractor.save_pretrained(training_args.output_dir) + tokenizer.save_pretrained(training_args.output_dir) + config.save_pretrained(training_args.output_dir) + + processor = AutoProcessor.from_pretrained(training_args.output_dir) + + data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding( + processor=processor, + input_padding="longest", + pad_input_to_multiple_of=pad_input_to_multiple_of, + max_label_length=data_args.max_label_length, + ) + + # Enable tensorboard only on the master node + has_tensorboard = is_tensorboard_available() + if has_tensorboard and jax.process_index() == 0: + try: + from flax.metrics.tensorboard import SummaryWriter + + summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) + except ImportError as ie: + has_tensorboard = False + logger.warning( + f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" + ) + else: + logger.warning( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run `pip install tensorboard` to enable." + ) + + # 10. Handle the repository creation + if training_args.push_to_hub: + with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f: + git_lfs_extensions = f.read() + if "*.wandb" not in git_lfs_extensions: + f.write("*.wandb filter=lfs diff=lfs merge=lfs -text") + if training_args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) + else: + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) + + # 11. 
Initialize our training + rng = jax.random.PRNGKey(training_args.seed) + rng, dropout_rng = jax.random.split(rng) + + # Store some constants + max_steps = int(training_args.max_steps) + gradient_accumulation_steps = int(training_args.gradient_accumulation_steps) + train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() + batch_size_per_update = train_batch_size * gradient_accumulation_steps + per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) + eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() + to_dtype = to_bf16 if training_args.mixed_precision else to_fp32 + + if training_args.do_train: + num_train_samples = len(vectorized_datasets[data_args.train_split_name]) + steps_per_epoch = num_train_samples // batch_size_per_update + if max_steps > 0: + num_epochs = -(training_args.max_steps // -steps_per_epoch) + total_train_steps = max_steps + else: + num_epochs = int(training_args.num_train_epochs) + total_train_steps = steps_per_epoch * num_epochs + + # Create learning rate schedule + linear_decay_lr_schedule_fn = create_learning_rate_fn( + total_train_steps, + training_args.warmup_steps, + training_args.learning_rate, + ) + + # We use Optax's "masking" functionality to not apply weight decay + # to bias and LayerNorm scale parameters. decay_mask_fn returns a + # mask boolean with the same structure as the parameters. + # The mask is True for parameters that should be decayed. + # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart. + # For FlaxT5, one should correct the layer norm parameter naming + # accordingly - see `run_t5_mlm_flax.py` e.g. + def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + layer_norm_params = [ + (name, "scale") + for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"] + ] + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + + if training_args.adafactor: + # Create Adafactor optimizer + optim = optax.adafactor( + learning_rate=linear_decay_lr_schedule_fn, + dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32, + weight_decay_rate=training_args.weight_decay, + weight_decay_mask=decay_mask_fn, + ) + else: + # Create AdamW optimizer + optim = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn, + ) + + # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. 
gradient accumulation steps > 1) + if training_args.multisteps and gradient_accumulation_steps > 1: + optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False) + else: + num_epochs = 0 + total_train_steps = 0 + num_train_samples = 0 + optim = None + + # Setup train state + state = MixedPrecisionTrainState.create( + step=data_args.skip_steps, + apply_fn=model.__call__, + get_attention_mask_fn=model._get_feature_vector_attention_mask, + params=model.params, + tx=optim, + to_dtype=to_dtype, + dropout_rng=dropout_rng, + max_grad_norm=training_args.max_grad_norm, + ) + + # Replicate the train state on each device + state = state.replicate() + blank_id = model.config.pad_token_id + + # Define gradient update step fn + def train_step(state, batch): + # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch + dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + + def compute_loss(params, minibatch): + labels = minibatch.pop("labels") + logits = state.apply_fn( + **minibatch, + params=params, + dropout_rng=dropout_rng, + freeze_feature_encoder=model_args.freeze_feature_encoder, + train=True, + )[0] + logits_mask = state.get_attention_mask_fn(logits.shape[1], batch["attention_mask"]) + loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean") + + return loss + + grad_fn = jax.value_and_grad(compute_loss) + + if gradient_accumulation_steps == 1 or training_args.multisteps: + loss, grad = grad_fn(to_dtype(state.params), batch) + + # Custom gradient accumulation + else: + # add a first dimension over gradient_accumulation_steps for minibatch slices + batch = jax.tree_util.tree_map( + lambda x: x.reshape( + gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::] + ), + batch, + ) + + def accum_minibatch_step(accum_grad, minibatch): + # compute loss, num labels and grad over minibatch and accumulate + loss, grad = grad_fn(to_dtype(state.params), minibatch) + return jax.tree_util.tree_map(jnp.add, accum_grad, grad), loss + + # create an initial state for accumulating losses, num labels and gradients + init_grad = jax.tree_util.tree_map(jnp.zeros_like, to_dtype(state.params)) + # loop accum minibatch step over the number of gradient accumulation steps + grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch) + + # update state + new_state = state.apply_gradients( + grads=grad, + dropout_rng=new_dropout_rng, + to_dtype=to_dtype, + ) + + # compute gradient norms over all layers and globally for detailed monitoring + layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad) + logs = { + "layer_grad_norm": layer_grad_norm, + "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)), + } + + # compute parameter norms over all layers and globally for detailed monitoring + layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params) + logs["layer_param_norm"] = layer_param_norm + logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm)) + + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + metrics.update(logs) + + metrics = jax.lax.pmean(metrics, axis_name="batch") + # metrics = to_fp32(metrics) + + return new_state, metrics + + # Define eval fn + def eval_step(params, batch): + labels = batch.pop("labels") + logits = model(**batch, params=params, train=False)[0] + + logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], 
batch["attention_mask"]) + loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean") + + pred_ids = jnp.argmax(logits, axis=-1) + + # summarize metrics + metrics = {"loss": loss} + metrics = jax.lax.pmean(metrics, axis_name="batch") + # metrics = to_fp32(metrics) + return metrics, pred_ids + + # Create parallel version of the train and eval step + if training_args.do_train: + p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) + + if training_args.do_eval: + p_eval_step = jax.pmap(eval_step, "batch") + + def run_evaluation(step): + if training_args.do_eval: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_preds = [] + eval_labels = [] + + # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length + eval_samples_idx = get_grouped_indices(vectorized_datasets[data_args.eval_split_name], eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) + + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): + samples = [vectorized_datasets[data_args.eval_split_name][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + labels = batch["labels"] + + metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size) + eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))) + eval_metrics.append(metrics) + + eval_labels.extend(labels) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + eval_metrics = to_fp32(eval_metrics) + + # always run compute metrics + error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels) + eval_metrics.update(error_rate_metric) + error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()]) + + # Print metrics and update progress bar + desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})" + epochs.write(desc) + epochs.desc = desc + + # Save metrics + write_wandb_log(eval_metrics, step, prefix="eval") + write_wandb_pred(pred_str, label_str, step) + # if has_tensorboard and jax.process_index() == 0: + # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str) + + def save_checkpoint(step): + # save and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False) + + skip_epochs = data_args.skip_steps // (num_train_samples // batch_size_per_update) + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_train_samples}") + logger.info(f" Num Epochs = {num_epochs}") + logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") + logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}") + logger.info(f" Total train batch size (w. 
parallel & distributed) = {batch_size_per_update}") + logger.info(f" Total optimization steps = {total_train_steps}") + logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}") + logger.info(f" Use scan: {config.use_scan}") + logger.info(f" Fuse matmuls: {config.fuse_matmuls}") + logger.info(f" Skipping: {data_args.skip_steps} steps ({skip_epochs} epochs)") + + train_time = cur_step = 0 + epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) + for epoch in epochs: + if training_args.do_train: + # ======================== Training ================================ + train_start = time.time() + + if epoch < skip_epochs: + logger.info(f"Skipping epoch {epoch + 1}") + # Create sampling rng + rng, input_rng = jax.random.split(rng) + continue + + # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length + train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng) + train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update) + + if data_args.skip_steps > cur_step: + logger.info(f"Skipping {data_args.skip_steps - (epoch * (num_train_samples // batch_size_per_update))} steps...") + # Gather the indices for creating the batch and do a training step + for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1): + cur_step = epoch * (num_train_samples // batch_size_per_update) + step + if cur_step <= data_args.skip_steps: + continue + + samples = [vectorized_datasets[data_args.train_split_name][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + batch = shard(batch.data) + try: + state, train_metric = p_train_step(state, batch) + except TypeError as e: + logger.warning(f"Encountered the following error:\n{e}") + + + if cur_step % training_args.logging_steps == 0: + # Save metrics + train_metric = unreplicate(train_metric) + train_time += time.time() - train_start + # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step + write_wandb_log(to_fp32(train_metric), cur_step, prefix=data_args.train_split_name) + # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis) + # if has_tensorboard and jax.process_index() == 0: + # write_train_metric(summary_writer, train_metrics, train_time, cur_step) + + epochs.write( + f"Step... 
({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})" + ) + p_train_step.clear_cache() + + if cur_step % total_train_steps == 0: + break + + if training_args.eval_steps and cur_step % training_args.eval_steps == 0: + run_evaluation(cur_step) + + if cur_step % training_args.save_steps == 0: + save_checkpoint(cur_step) + + if training_args.eval_steps == 0 and (epoch + 1) != num_epochs: + # run evaluation at the end of the epoch if eval steps are not specified + run_evaluation(cur_step) + save_checkpoint(cur_step) + + if training_args.do_train: + save_checkpoint(cur_step) + + cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training + + if training_args.do_eval: + run_evaluation(cur_step) + + # TODO: collapse 'do_predict' into the run_evaluation function + if training_args.do_predict: + for split in [data_args.test_split_name]: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_preds = [] + eval_labels = [] + + # Generate eval set by sequentially sampling indices from the test dataset and grouping by length + eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) + + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)): + samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + labels = batch["labels"] + + metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size) + eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))) + eval_metrics.append(metrics) + + eval_labels.extend(labels) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + eval_metrics = to_fp32(eval_metrics) + + # always run compute metrics + error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels) + eval_metrics.update(error_rate_metric) + error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()]) + + # Print metrics and update progress bar + desc = f"Step... 
({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})" + epochs.write(desc) + epochs.desc = desc + + # Save metrics + write_wandb_log(eval_metrics, cur_step, prefix=split) + write_wandb_pred(pred_str, label_str, cur_step, prefix=split) + # if has_tensorboard and jax.process_index() == 0: + # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str) + + +if __name__ == "__main__": + main() diff --git a/wandb/run-20220810_151736-2jo5la5b/files/config.yaml b/wandb/run-20220810_151736-2jo5la5b/files/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..540dd59c53bea01d8da34d8d05810e1b32ed166d --- /dev/null +++ b/wandb/run-20220810_151736-2jo5la5b/files/config.yaml @@ -0,0 +1,33 @@ +wandb_version: 1 + +_wandb: + desc: null + value: + cli_version: 0.12.9 + code_path: code/run_flax_speech_recognition_ctc.py + framework: huggingface + huggingface_version: 4.21.0 + is_jupyter_run: false + is_kaggle_kernel: false + python_version: 3.8.10 + start_time: 1660144656 + t: + 1: + - 1 + - 2 + - 3 + - 11 + - 12 + 2: + - 1 + - 2 + - 3 + - 11 + - 12 + 3: + - 13 + 4: 3.8.10 + 5: 0.12.9 + 6: 4.21.0 + 8: + - 5 diff --git a/wandb/run-20220810_151736-2jo5la5b/files/diff.patch b/wandb/run-20220810_151736-2jo5la5b/files/diff.patch new file mode 100644 index 0000000000000000000000000000000000000000..adad3d94dbb1b624fde06635a86d4d5e00cbb316 --- /dev/null +++ b/wandb/run-20220810_151736-2jo5la5b/files/diff.patch @@ -0,0 +1,144 @@ +diff --git a/run.recover.sh b/run.recover.sh +index 77ad3fd..632a336 100755 +--- a/run.recover.sh ++++ b/run.recover.sh +@@ -11,9 +11,8 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c + --per_device_train_batch_size="2" \ + --per_device_eval_batch_size="2" \ + --gradient_accumulation_steps="1" \ +- --precision="full_mixed" \ ++ --precision="half_mixed" \ + --matmul_precision="bfloat16" \ +- --multisteps \ + --learning_rate="6.394633237505332e-05" \ + --skip_steps="275000" \ + --warmup_steps="2000" \ +diff --git a/run_flax_speech_recognition_ctc.py b/run_flax_speech_recognition_ctc.py +index a330879..4d0f5fc 100644 +--- a/run_flax_speech_recognition_ctc.py ++++ b/run_flax_speech_recognition_ctc.py +@@ -415,12 +415,12 @@ class MixedPrecisionTrainState(struct.PyTreeNode): + ) + + @classmethod +- def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs): ++ def create(cls, *, apply_fn, params, tx, to_dtype, step=0, **kwargs): + """Creates a new instance with `step=0` and initialized `opt_state`.""" + # downcast optimizer state to bf16 if mixed-precision training + opt_state = tx.init(to_dtype(params)) if tx is not None else None + return cls( +- step=0, ++ step=step, + apply_fn=apply_fn, + params=params, + tx=tx, +@@ -1339,6 +1339,7 @@ def main(): + + # Setup train state + state = MixedPrecisionTrainState.create( ++ step=data_args.skip_steps, + apply_fn=model.__call__, + get_attention_mask_fn=model._get_feature_vector_attention_mask, + params=model.params, +@@ -1520,11 +1521,10 @@ def main(): + + if epoch < skip_epochs: + logger.info(f"Skipping epoch {epoch + 1}") ++ # Create sampling rng ++ rng, input_rng = jax.random.split(rng) + continue + +- # Create sampling rng +- rng, input_rng = jax.random.split(rng) +- + # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length + train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng) + train_batch_idx = 
generate_batch_splits(train_samples_idx, batch_size_per_update) +@@ -1559,6 +1559,7 @@ def main(): + epochs.write( + f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})" + ) ++ p_train_step.clear_cache() + + if cur_step % total_train_steps == 0: + break +diff --git a/special_tokens_map.json b/special_tokens_map.json +index 218961f..3c0d148 100644 +--- a/special_tokens_map.json ++++ b/special_tokens_map.json +@@ -399,6 +399,48 @@ + "rstrip": false, + "single_word": false + }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, + { + "content": "", + "lstrip": false, +diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log +index 23926ef..90a074d 120000 +--- a/wandb/debug-internal.log ++++ b/wandb/debug-internal.log +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4/logs/debug-internal.log +\ No newline at end of file ++run-20220810_151736-2jo5la5b/logs/debug-internal.log +\ No newline at end of file +diff --git a/wandb/debug.log b/wandb/debug.log +index 279853d..de899a6 120000 +--- a/wandb/debug.log ++++ b/wandb/debug.log +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4/logs/debug.log +\ No newline at end of file ++run-20220810_151736-2jo5la5b/logs/debug.log +\ No newline at end of file +diff --git a/wandb/latest-run b/wandb/latest-run +index f069a7a..0dfb7e0 120000 +--- a/wandb/latest-run ++++ b/wandb/latest-run +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4 +\ No newline at end of file ++run-20220810_151736-2jo5la5b +\ No newline at end of file diff --git a/wandb/run-20220810_151736-2jo5la5b/files/output.log b/wandb/run-20220810_151736-2jo5la5b/files/output.log new file mode 100644 index 0000000000000000000000000000000000000000..003c11658d9dd4b57af1f298cc91a6c0738b56b5 --- /dev/null +++ b/wandb/run-20220810_151736-2jo5la5b/files/output.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:89c3976bd22db27a53b27fd3d63f60a6563116c92d804aceb8ada0bf7909833f +size 224905 diff --git a/wandb/run-20220810_151736-2jo5la5b/files/requirements.txt b/wandb/run-20220810_151736-2jo5la5b/files/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e0273eb6554b8538eecc3cb9f4a47c988bd3d0dd --- /dev/null +++ b/wandb/run-20220810_151736-2jo5la5b/files/requirements.txt @@ -0,0 +1,158 @@ +absl-py==1.0.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +appdirs==1.4.4 +astunparse==1.6.3 +async-timeout==4.0.2 +attrs==21.4.0 +audioread==2.1.9 +backcall==0.2.0 +cachetools==4.2.4 +certifi==2021.10.8 +cffi==1.15.1 +charset-normalizer==2.0.10 +chex==0.1.3 +click==8.0.3 +cloud-tpu-client==0.10 +cloud-tpu-profiler==2.4.0 +clu==0.0.6 +colorama==0.4.5 +commonmark==0.9.1 +configparser==5.2.0 +contextlib2==21.6.0 +cycler==0.11.0 +datasets==2.4.0 +decorator==5.1.0 +dill==0.3.4 +dm-tree==0.1.6 +docker-pycreds==0.4.0 +etils==0.6.0 +exceptiongroup==1.0.0rc8 
+filelock==3.4.2 +flatbuffers==2.0 +flax==0.5.3 +fonttools==4.28.5 +frozenlist==1.2.0 +fsspec==2021.11.1 +future==0.18.2 +gast==0.4.0 +gitdb==4.0.9 +gitpython==3.1.26 +google-api-core==1.31.5 +google-api-python-client==1.8.0 +google-auth-httplib2==0.1.0 +google-auth-oauthlib==0.4.6 +google-auth==2.3.3 +google-pasta==0.2.0 +googleapis-common-protos==1.54.0 +grpcio==1.43.0 +h5py==3.6.0 +httplib2==0.20.2 +huggingface-hub==0.2.1 +hypothesis==6.53.0 +idna==3.3 +importlib-metadata==4.10.0 +importlib-resources==5.4.0 +ipython==7.31.0 +jax==0.3.15 +jaxlib==0.3.15 +jedi==0.18.1 +jiwer==2.3.0 +joblib==1.1.0 +keras-preprocessing==1.1.2 +keras==2.7.0 +kiwisolver==1.3.2 +libclang==12.0.0 +librosa==0.9.2 +libtpu-nightly==0.1.dev20220722 +llvmlite==0.39.0 +markdown==3.3.6 +matplotlib-inline==0.1.3 +matplotlib==3.5.1 +ml-collections==0.1.0 +msgpack==1.0.3 +multidict==5.2.0 +multiprocess==0.70.12.2 +numba==0.56.0 +numpy==1.22.0 +oauth2client==4.1.3 +oauthlib==3.1.1 +opt-einsum==3.3.0 +optax==0.1.3 +packaging==21.3 +pandas==1.3.5 +parso==0.8.3 +pathtools==0.1.2 +pexpect==4.8.0 +pickleshare==0.7.5 +pillow==9.0.0 +pip==22.2.2 +pkg-resources==0.0.0 +pooch==1.6.0 +promise==2.3 +prompt-toolkit==3.0.24 +protobuf==3.19.1 +psutil==5.9.0 +ptyprocess==0.7.0 +pyarrow==6.0.1 +pyasn1-modules==0.2.8 +pyasn1==0.4.8 +pycparser==2.21 +pyctcdecode==0.4.0 +pygments==2.11.1 +pygtrie==2.5.0 +pyparsing==3.0.6 +python-dateutil==2.8.2 +python-levenshtein==0.12.2 +pytz==2021.3 +pyyaml==6.0 +regex==2021.11.10 +requests-oauthlib==1.3.0 +requests==2.27.0 +resampy==0.3.1 +responses==0.18.0 +rich==11.2.0 +rsa==4.8 +sacremoses==0.0.46 +scikit-learn==1.1.1 +scipy==1.7.3 +sentry-sdk==1.5.2 +setuptools==44.0.0 +shortuuid==1.0.8 +six==1.16.0 +smmap==5.0.0 +sortedcontainers==2.4.0 +soundfile==0.10.3.post1 +sox==1.4.1 +subprocess32==3.5.4 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.0 +tensorboard==2.7.0 +tensorflow-cpu==2.7.0 +tensorflow-datasets==4.4.0 +tensorflow-estimator==2.7.0 +tensorflow-io-gcs-filesystem==0.23.1 +tensorflow-metadata==1.5.0 +tensorflow==2.7.0 +tensorstore==0.1.21 +termcolor==1.1.0 +threadpoolctl==3.1.0 +tokenizers==0.11.2 +toolz==0.11.2 +torch==1.12.0 +torchaudio==0.12.0+cpu +tqdm==4.62.3 +traitlets==5.1.1 +transformers==4.21.0 +typing-extensions==4.3.0 +uritemplate==3.0.1 +urllib3==1.26.7 +wandb==0.12.9 +wcwidth==0.2.5 +werkzeug==2.0.2 +wheel==0.37.1 +wrapt==1.13.3 +xxhash==2.0.2 +yarl==1.7.2 +yaspin==2.1.0 +zipp==3.7.0 \ No newline at end of file diff --git a/wandb/run-20220810_151736-2jo5la5b/files/wandb-metadata.json b/wandb/run-20220810_151736-2jo5la5b/files/wandb-metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..07fcd5fac35c4a9d340c8996e52f0e8ea45c09be --- /dev/null +++ b/wandb/run-20220810_151736-2jo5la5b/files/wandb-metadata.json @@ -0,0 +1,69 @@ +{ + "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29", + "python": "3.8.10", + "heartbeatAt": "2022-08-10T15:17:39.930151", + "startedAt": "2022-08-10T15:17:36.501050", + "docker": null, + "cpu_count": 96, + "cuda": null, + "args": [ + "--model_name_or_path=./", + "--hub_model_id=NbAiLab/wav2vec2-1b-npsc-nst-tpu", + "--tokenizer_name=./", + "--output_dir=./", + "--overwrite_output_dir", + "--num_train_epochs=40", + "--per_device_train_batch_size=2", + "--per_device_eval_batch_size=2", + "--gradient_accumulation_steps=1", + "--precision=half_mixed", + "--matmul_precision=bfloat16", + "--learning_rate=6.394633237505332e-05", + "--skip_steps=275000", + "--warmup_steps=2000", + "--length_column_name=input_length", + 
"--evaluation_strategy=steps", + "--text_column_name=text", + "--save_steps=5000", + "--eval_steps=5000", + "--logging_steps=100", + "--layerdrop=0.041", + "--attention_dropout=0.094", + "--activation_dropout=0.055", + "--hidden_dropout=0.047", + "--save_total_limit=5", + "--freeze_feature_encoder", + "--feat_proj_dropout=0.04", + "--mask_time_prob=0.082", + "--mask_time_length=10", + "--mask_feature_prob=0.25", + "--mask_feature_length=64", + "--gradient_checkpointing", + "--min_duration_in_seconds=0.5", + "--max_duration_in_seconds=30.0", + "--use_auth_token", + "--seed=42", + "--group_by_length", + "--do_train", + "--do_eval", + "--push_to_hub", + "--preprocessing_num_workers=32", + "--ctc_zero_infinity", + "--do_lower_case", + "--wandb_project=wav2vec2", + "--wandb_name=wav2vec2-1b-npsc-nst-tpu (cont.)", + "--remove_punctuation" + ], + "state": "running", + "program": "run_flax_speech_recognition_ctc.py", + "codePath": "run_flax_speech_recognition_ctc.py", + "git": { + "remote": "https://huggingface.co/NbAiLab/wav2vec2-1b-npsc-nst-tpu", + "commit": "f624ac4bfedfbf56891676d7c5f2e37b4c8e0745" + }, + "email": "versae@gmail.com", + "root": "/data/wav2vec2-1b-npsc-nst-tpu", + "host": "t1v-n-eedfb410-w-0", + "username": "javierr", + "executable": "/data/flax/bin/python" +} diff --git a/wandb/run-20220810_151736-2jo5la5b/files/wandb-summary.json b/wandb/run-20220810_151736-2jo5la5b/files/wandb-summary.json new file mode 100644 index 0000000000000000000000000000000000000000..4e50de9097cd7f498c9c2ffa8e8adf34a9960dad --- /dev/null +++ b/wandb/run-20220810_151736-2jo5la5b/files/wandb-summary.json @@ -0,0 +1 @@ +{"train/grad_norm": 4.261558532714844, "layer_grad_norm/": {"lm_head": {"bias": 0.017788879573345184, "kernel": 2.3992724418640137}, "wav2vec2": {"encoder": {"layer_norm": {"bias": 0.0323975533246994, "scale": 0.035100195556879044}, "layers": {"0": {"attention": {"k_proj": {"bias": 6.96770366630517e-05, "kernel": 0.017761101946234703}, "out_proj": {"bias": 0.025619633495807648, "kernel": 0.16803325712680817}, "q_proj": {"bias": 0.0018031439976766706, "kernel": 0.022346343845129013}, "v_proj": {"bias": 0.021983297541737556, "kernel": 0.16035471856594086}}, "feed_forward": {"intermediate_dense": {"bias": 0.0326140932738781, "kernel": 0.39286553859710693}, "output_dense": {"bias": 0.010046987794339657, "kernel": 0.32385218143463135}}, "final_layer_norm": {"bias": 0.09299281984567642, "scale": 0.26545006036758423}, "layer_norm": {"bias": 0.04185380041599274, "scale": 0.045830436050891876}}, "1": {"attention": {"k_proj": {"bias": 4.200146213406697e-05, "kernel": 0.02062308043241501}, "out_proj": {"bias": 0.011040832847356796, "kernel": 0.10348627716302872}, "q_proj": {"bias": 0.001808196771889925, "kernel": 0.02417493239045143}, "v_proj": {"bias": 0.016231240704655647, "kernel": 0.07038241624832153}}, "feed_forward": {"intermediate_dense": {"bias": 0.014394059777259827, "kernel": 0.22140879929065704}, "output_dense": {"bias": 0.009978776797652245, "kernel": 0.18486550450325012}}, "final_layer_norm": {"bias": 0.02638934552669525, "scale": 0.022694334387779236}, "layer_norm": {"bias": 0.0309704951941967, "scale": 0.01734926551580429}}, "10": {"attention": {"k_proj": {"bias": 2.6716719730757177e-05, "kernel": 0.07760775834321976}, "out_proj": {"bias": 0.009652212262153625, "kernel": 0.1277025192975998}, "q_proj": {"bias": 0.005531705915927887, "kernel": 0.08602005243301392}, "v_proj": {"bias": 0.014598616398870945, "kernel": 0.11959479004144669}}, "feed_forward": {"intermediate_dense": {"bias": 
0.014417275786399841, "kernel": 0.2236904352903366}, "output_dense": {"bias": 0.009385183453559875, "kernel": 0.1884544938802719}}, "final_layer_norm": {"bias": 0.024149442091584206, "scale": 0.02101276069879532}, "layer_norm": {"bias": 0.03126443177461624, "scale": 0.03499530255794525}}, "11": {"attention": {"k_proj": {"bias": 6.491513340733945e-05, "kernel": 0.1055336743593216}, "out_proj": {"bias": 0.009826489724218845, "kernel": 0.22929716110229492}, "q_proj": {"bias": 0.006276742555201054, "kernel": 0.09677979350090027}, "v_proj": {"bias": 0.016340602189302444, "kernel": 0.19895723462104797}}, "feed_forward": {"intermediate_dense": {"bias": 0.013293663039803505, "kernel": 0.21519728004932404}, "output_dense": {"bias": 0.009642157703638077, "kernel": 0.17639335989952087}}, "final_layer_norm": {"bias": 0.022317882627248764, "scale": 0.020071465522050858}, "layer_norm": {"bias": 0.034088365733623505, "scale": 0.038652658462524414}}, "12": {"attention": {"k_proj": {"bias": 3.970207035308704e-05, "kernel": 0.10333496332168579}, "out_proj": {"bias": 0.009958583861589432, "kernel": 0.1805419921875}, "q_proj": {"bias": 0.0065935952588915825, "kernel": 0.09627208113670349}, "v_proj": {"bias": 0.01603534072637558, "kernel": 0.16809365153312683}}, "feed_forward": {"intermediate_dense": {"bias": 0.014714915305376053, "kernel": 0.2288183867931366}, "output_dense": {"bias": 0.009768988937139511, "kernel": 0.1851433366537094}}, "final_layer_norm": {"bias": 0.02528025582432747, "scale": 0.021869206801056862}, "layer_norm": {"bias": 0.03297863155603409, "scale": 0.03648758307099342}}, "13": {"attention": {"k_proj": {"bias": 7.583085243823007e-05, "kernel": 0.09878440946340561}, "out_proj": {"bias": 0.010364462621510029, "kernel": 0.24141821265220642}, "q_proj": {"bias": 0.006039786152541637, "kernel": 0.0987459123134613}, "v_proj": {"bias": 0.01709395833313465, "kernel": 0.2135884016752243}}, "feed_forward": {"intermediate_dense": {"bias": 0.01545906811952591, "kernel": 0.24038702249526978}, "output_dense": {"bias": 0.010241981595754623, "kernel": 0.20683333277702332}}, "final_layer_norm": {"bias": 0.02676304057240486, "scale": 0.025948306545615196}, "layer_norm": {"bias": 0.03223804384469986, "scale": 0.028945578262209892}}, "14": {"attention": {"k_proj": {"bias": 4.3531006667762995e-05, "kernel": 0.07849182188510895}, "out_proj": {"bias": 0.010953404009342194, "kernel": 0.1908419132232666}, "q_proj": {"bias": 0.0049440511502325535, "kernel": 0.07795748114585876}, "v_proj": {"bias": 0.01730860397219658, "kernel": 0.17354583740234375}}, "feed_forward": {"intermediate_dense": {"bias": 0.01733974926173687, "kernel": 0.2709116041660309}, "output_dense": {"bias": 0.010775217786431313, "kernel": 0.23137901723384857}}, "final_layer_norm": {"bias": 0.03220454603433609, "scale": 0.03206019848585129}, "layer_norm": {"bias": 0.03320500999689102, "scale": 0.02874942496418953}}, "15": {"attention": {"k_proj": {"bias": 0.0001445362577214837, "kernel": 0.2262178361415863}, "out_proj": {"bias": 0.009758997708559036, "kernel": 0.3166901469230652}, "q_proj": {"bias": 0.01526635978370905, "kernel": 0.24137938022613525}, "v_proj": {"bias": 0.016232844442129135, "kernel": 0.23615804314613342}}, "feed_forward": {"intermediate_dense": {"bias": 0.015659993514418602, "kernel": 0.24640920758247375}, "output_dense": {"bias": 0.009648777544498444, "kernel": 0.22626811265945435}}, "final_layer_norm": {"bias": 0.02658914588391781, "scale": 0.024082619696855545}, "layer_norm": {"bias": 0.039305880665779114, "scale": 
0.047641269862651825}}, "16": {"attention": {"k_proj": {"bias": 6.230986036825925e-05, "kernel": 0.11781159043312073}, "out_proj": {"bias": 0.00993018876761198, "kernel": 0.22083336114883423}, "q_proj": {"bias": 0.007844965904951096, "kernel": 0.12069929391145706}, "v_proj": {"bias": 0.015241955406963825, "kernel": 0.17874589562416077}}, "feed_forward": {"intermediate_dense": {"bias": 0.015298807062208652, "kernel": 0.24970686435699463}, "output_dense": {"bias": 0.009682225994765759, "kernel": 0.2286539077758789}}, "final_layer_norm": {"bias": 0.027272606268525124, "scale": 0.02279319055378437}, "layer_norm": {"bias": 0.030095037072896957, "scale": 0.02651960961520672}}, "17": {"attention": {"k_proj": {"bias": 2.7123965992359444e-05, "kernel": 0.08686056733131409}, "out_proj": {"bias": 0.010655339807271957, "kernel": 0.1433088779449463}, "q_proj": {"bias": 0.00538025563582778, "kernel": 0.08902112394571304}, "v_proj": {"bias": 0.016178736463189125, "kernel": 0.1543588638305664}}, "feed_forward": {"intermediate_dense": {"bias": 0.01603703759610653, "kernel": 0.26329636573791504}, "output_dense": {"bias": 0.010479980148375034, "kernel": 0.23315469920635223}}, "final_layer_norm": {"bias": 0.026221446692943573, "scale": 0.02159356325864792}, "layer_norm": {"bias": 0.0314837209880352, "scale": 0.023964006453752518}}, "18": {"attention": {"k_proj": {"bias": 6.195436435518786e-05, "kernel": 0.1987147331237793}, "out_proj": {"bias": 0.01037890650331974, "kernel": 0.2379658818244934}, "q_proj": {"bias": 0.01076379232108593, "kernel": 0.1747097373008728}, "v_proj": {"bias": 0.015858840197324753, "kernel": 0.1839509755373001}}, "feed_forward": {"intermediate_dense": {"bias": 0.015994738787412643, "kernel": 0.2778272330760956}, "output_dense": {"bias": 0.010189207270741463, "kernel": 0.24560357630252838}}, "final_layer_norm": {"bias": 0.02731412649154663, "scale": 0.024946974590420723}, "layer_norm": {"bias": 0.036750297993421555, "scale": 0.046463578939437866}}, "19": {"attention": {"k_proj": {"bias": 5.6334753026021644e-05, "kernel": 0.10678335279226303}, "out_proj": {"bias": 0.010594839230179787, "kernel": 0.1787024736404419}, "q_proj": {"bias": 0.006550152786076069, "kernel": 0.10678353905677795}, "v_proj": {"bias": 0.01596170663833618, "kernel": 0.15820428729057312}}, "feed_forward": {"intermediate_dense": {"bias": 0.01575728878378868, "kernel": 0.2879087030887604}, "output_dense": {"bias": 0.010250548832118511, "kernel": 0.25105780363082886}}, "final_layer_norm": {"bias": 0.02602769061923027, "scale": 0.027589356526732445}, "layer_norm": {"bias": 0.028630346059799194, "scale": 0.029089417308568954}}, "2": {"attention": {"k_proj": {"bias": 3.820758865913376e-05, "kernel": 0.03627895936369896}, "out_proj": {"bias": 0.011254100129008293, "kernel": 0.12343056499958038}, "q_proj": {"bias": 0.0029873033054172993, "kernel": 0.03826657682657242}, "v_proj": {"bias": 0.01840016432106495, "kernel": 0.10382330417633057}}, "feed_forward": {"intermediate_dense": {"bias": 0.014908352866768837, "kernel": 0.2502833604812622}, "output_dense": {"bias": 0.010285001248121262, "kernel": 0.17825302481651306}}, "final_layer_norm": {"bias": 0.02979440428316593, "scale": 0.02249833196401596}, "layer_norm": {"bias": 0.03277341276407242, "scale": 0.050777725875377655}}, "20": {"attention": {"k_proj": {"bias": 1.5044915016915184e-05, "kernel": 0.05861634016036987}, "out_proj": {"bias": 0.011103209108114243, "kernel": 0.0989118367433548}, "q_proj": {"bias": 0.002783840987831354, "kernel": 0.057587526738643646}, "v_proj": 
{"bias": 0.01630830019712448, "kernel": 0.10167922079563141}}, "feed_forward": {"intermediate_dense": {"bias": 0.01605488918721676, "kernel": 0.2995821237564087}, "output_dense": {"bias": 0.010747051797807217, "kernel": 0.25055864453315735}}, "final_layer_norm": {"bias": 0.027223603799939156, "scale": 0.030495693907141685}, "layer_norm": {"bias": 0.028251871466636658, "scale": 0.023164357990026474}}, "21": {"attention": {"k_proj": {"bias": 2.638094338180963e-05, "kernel": 0.08226853609085083}, "out_proj": {"bias": 0.011270384304225445, "kernel": 0.15913301706314087}, "q_proj": {"bias": 0.004133924841880798, "kernel": 0.08240459114313126}, "v_proj": {"bias": 0.017001166939735413, "kernel": 0.1452832967042923}}, "feed_forward": {"intermediate_dense": {"bias": 0.016547638922929764, "kernel": 0.31675517559051514}, "output_dense": {"bias": 0.011030763387680054, "kernel": 0.2627480626106262}}, "final_layer_norm": {"bias": 0.027790464460849762, "scale": 0.03074125573039055}, "layer_norm": {"bias": 0.028280116617679596, "scale": 0.02932114712893963}}, "22": {"attention": {"k_proj": {"bias": 2.3005515686236322e-05, "kernel": 0.08692421019077301}, "out_proj": {"bias": 0.011904345825314522, "kernel": 0.13713234663009644}, "q_proj": {"bias": 0.004867547657340765, "kernel": 0.08724722266197205}, "v_proj": {"bias": 0.017674535512924194, "kernel": 0.1340729296207428}}, "feed_forward": {"intermediate_dense": {"bias": 0.018127482384443283, "kernel": 0.34236806631088257}, "output_dense": {"bias": 0.011649074032902718, "kernel": 0.27650779485702515}}, "final_layer_norm": {"bias": 0.032341137528419495, "scale": 0.03339264914393425}, "layer_norm": {"bias": 0.03334794566035271, "scale": 0.03606652468442917}}, "23": {"attention": {"k_proj": {"bias": 0.00014989564078859985, "kernel": 0.14326578378677368}, "out_proj": {"bias": 0.012433766387403011, "kernel": 0.33488672971725464}, "q_proj": {"bias": 0.006895859260112047, "kernel": 0.13656491041183472}, "v_proj": {"bias": 0.020493322983384132, "kernel": 0.2837929129600525}}, "feed_forward": {"intermediate_dense": {"bias": 0.018564969301223755, "kernel": 0.3712247312068939}, "output_dense": {"bias": 0.012342891655862331, "kernel": 0.28736400604248047}}, "final_layer_norm": {"bias": 0.034619107842445374, "scale": 0.03262517601251602}, "layer_norm": {"bias": 0.0374518446624279, "scale": 0.04078807309269905}}, "24": {"attention": {"k_proj": {"bias": 7.870129775255919e-05, "kernel": 0.17635203897953033}, "out_proj": {"bias": 0.0117869907990098, "kernel": 0.28269436955451965}, "q_proj": {"bias": 0.009934404864907265, "kernel": 0.17477506399154663}, "v_proj": {"bias": 0.020976759493350983, "kernel": 0.2593054473400116}}, "feed_forward": {"intermediate_dense": {"bias": 0.018721435219049454, "kernel": 0.39080867171287537}, "output_dense": {"bias": 0.011417874135077, "kernel": 0.266484797000885}}, "final_layer_norm": {"bias": 0.03424592316150665, "scale": 0.03418727219104767}, "layer_norm": {"bias": 0.04514043405652046, "scale": 0.03920702636241913}}, "25": {"attention": {"k_proj": {"bias": 6.484580808319151e-05, "kernel": 0.14558939635753632}, "out_proj": {"bias": 0.012111745774745941, "kernel": 0.24085207283496857}, "q_proj": {"bias": 0.006976753007620573, "kernel": 0.1353585571050644}, "v_proj": {"bias": 0.02140776999294758, "kernel": 0.2526576519012451}}, "feed_forward": {"intermediate_dense": {"bias": 0.01935059204697609, "kernel": 0.402408629655838}, "output_dense": {"bias": 0.011820340529084206, "kernel": 0.25476741790771484}}, "final_layer_norm": {"bias": 
0.036345649510622025, "scale": 0.03660818934440613}, "layer_norm": {"bias": 0.040864236652851105, "scale": 0.036443017423152924}}, "26": {"attention": {"k_proj": {"bias": 8.441523823421448e-05, "kernel": 0.1542394459247589}, "out_proj": {"bias": 0.011885873973369598, "kernel": 0.2426203340291977}, "q_proj": {"bias": 0.007985231466591358, "kernel": 0.15195505321025848}, "v_proj": {"bias": 0.021539820358157158, "kernel": 0.2561507225036621}}, "feed_forward": {"intermediate_dense": {"bias": 0.018558435142040253, "kernel": 0.3671436011791229}, "output_dense": {"bias": 0.011674494482576847, "kernel": 0.25239574909210205}}, "final_layer_norm": {"bias": 0.03317685052752495, "scale": 0.037001386284828186}, "layer_norm": {"bias": 0.04099573194980621, "scale": 0.03664571791887283}}, "27": {"attention": {"k_proj": {"bias": 0.00012791654444299638, "kernel": 0.2498001754283905}, "out_proj": {"bias": 0.010986501350998878, "kernel": 0.3230920433998108}, "q_proj": {"bias": 0.012237230315804482, "kernel": 0.2323647141456604}, "v_proj": {"bias": 0.02166379615664482, "kernel": 0.32181301712989807}}, "feed_forward": {"intermediate_dense": {"bias": 0.019342927262187004, "kernel": 0.37609928846359253}, "output_dense": {"bias": 0.010700833052396774, "kernel": 0.26553863286972046}}, "final_layer_norm": {"bias": 0.03730005770921707, "scale": 0.03432407230138779}, "layer_norm": {"bias": 0.04656531661748886, "scale": 0.057998575270175934}}, "28": {"attention": {"k_proj": {"bias": 0.00011714182619471103, "kernel": 0.2847234308719635}, "out_proj": {"bias": 0.009179851040244102, "kernel": 0.2750641107559204}, "q_proj": {"bias": 0.013206155970692635, "kernel": 0.2713070213794708}, "v_proj": {"bias": 0.018852179870009422, "kernel": 0.284795343875885}}, "feed_forward": {"intermediate_dense": {"bias": 0.016067277640104294, "kernel": 0.32464420795440674}, "output_dense": {"bias": 0.008993763476610184, "kernel": 0.22799071669578552}}, "final_layer_norm": {"bias": 0.029480326920747757, "scale": 0.03185226768255234}, "layer_norm": {"bias": 0.04435642808675766, "scale": 0.04779055714607239}}, "29": {"attention": {"k_proj": {"bias": 0.00010104493412654847, "kernel": 0.18129537999629974}, "out_proj": {"bias": 0.00840977393090725, "kernel": 0.23119071125984192}, "q_proj": {"bias": 0.007738973014056683, "kernel": 0.175735205411911}, "v_proj": {"bias": 0.017744898796081543, "kernel": 0.26999735832214355}}, "feed_forward": {"intermediate_dense": {"bias": 0.01636805199086666, "kernel": 0.35230132937431335}, "output_dense": {"bias": 0.007491791155189276, "kernel": 0.20692045986652374}}, "final_layer_norm": {"bias": 0.030294589698314667, "scale": 0.03559590131044388}, "layer_norm": {"bias": 0.04036903753876686, "scale": 0.03787432610988617}}, "3": {"attention": {"k_proj": {"bias": 7.949249993544072e-05, "kernel": 0.05786384642124176}, "out_proj": {"bias": 0.011071368120610714, "kernel": 0.20154500007629395}, "q_proj": {"bias": 0.004657139535993338, "kernel": 0.054560501128435135}, "v_proj": {"bias": 0.01717137172818184, "kernel": 0.14990991353988647}}, "feed_forward": {"intermediate_dense": {"bias": 0.015468303114175797, "kernel": 0.243474543094635}, "output_dense": {"bias": 0.010136950761079788, "kernel": 0.17860807478427887}}, "final_layer_norm": {"bias": 0.02905898541212082, "scale": 0.022218134254217148}, "layer_norm": {"bias": 0.03191278874874115, "scale": 0.03522134944796562}}, "30": {"attention": {"k_proj": {"bias": 9.62297708611004e-05, "kernel": 0.17377516627311707}, "out_proj": {"bias": 0.007276182062923908, "kernel": 
0.2005656659603119}, "q_proj": {"bias": 0.007119656540453434, "kernel": 0.1555549055337906}, "v_proj": {"bias": 0.014824660494923592, "kernel": 0.2199939787387848}}, "feed_forward": {"intermediate_dense": {"bias": 0.013565706089138985, "kernel": 0.3095209300518036}, "output_dense": {"bias": 0.006479381583631039, "kernel": 0.16640284657478333}}, "final_layer_norm": {"bias": 0.024809623137116432, "scale": 0.025144733488559723}, "layer_norm": {"bias": 0.02748488262295723, "scale": 0.042448919266462326}}, "31": {"attention": {"k_proj": {"bias": 7.348716462729499e-05, "kernel": 0.1296117603778839}, "out_proj": {"bias": 0.006049656309187412, "kernel": 0.14480695128440857}, "q_proj": {"bias": 0.005557514727115631, "kernel": 0.11993050575256348}, "v_proj": {"bias": 0.01159297488629818, "kernel": 0.17019638419151306}}, "feed_forward": {"intermediate_dense": {"bias": 0.011785638518631458, "kernel": 0.26489073038101196}, "output_dense": {"bias": 0.005397414322942495, "kernel": 0.13867370784282684}}, "final_layer_norm": {"bias": 0.021065887063741684, "scale": 0.021553657948970795}, "layer_norm": {"bias": 0.022535139694809914, "scale": 0.024462278932332993}}, "32": {"attention": {"k_proj": {"bias": 7.606489089084789e-05, "kernel": 0.1093662902712822}, "out_proj": {"bias": 0.005289402790367603, "kernel": 0.12852157652378082}, "q_proj": {"bias": 0.004667255096137524, "kernel": 0.10470472276210785}, "v_proj": {"bias": 0.009674660861492157, "kernel": 0.14767077565193176}}, "feed_forward": {"intermediate_dense": {"bias": 0.009607319720089436, "kernel": 0.22237342596054077}, "output_dense": {"bias": 0.004497535061091185, "kernel": 0.11611822247505188}}, "final_layer_norm": {"bias": 0.018468791618943214, "scale": 0.01746383309364319}, "layer_norm": {"bias": 0.01872493326663971, "scale": 0.022064995020627975}}, "33": {"attention": {"k_proj": {"bias": 4.900983549305238e-05, "kernel": 0.10418825596570969}, "out_proj": {"bias": 0.004187482409179211, "kernel": 0.11690917611122131}, "q_proj": {"bias": 0.00446582306176424, "kernel": 0.09962163120508194}, "v_proj": {"bias": 0.008225696161389351, "kernel": 0.13791730999946594}}, "feed_forward": {"intermediate_dense": {"bias": 0.007958542555570602, "kernel": 0.1807086020708084}, "output_dense": {"bias": 0.0035951065365225077, "kernel": 0.09710787236690521}}, "final_layer_norm": {"bias": 0.01612575724720955, "scale": 0.01639324426651001}, "layer_norm": {"bias": 0.016356293112039566, "scale": 0.01753804460167885}}, "34": {"attention": {"k_proj": {"bias": 4.477999755181372e-05, "kernel": 0.1139673963189125}, "out_proj": {"bias": 0.0032220929861068726, "kernel": 0.09854735434055328}, "q_proj": {"bias": 0.0044842008501291275, "kernel": 0.10058338940143585}, "v_proj": {"bias": 0.006300156936049461, "kernel": 0.11815305054187775}}, "feed_forward": {"intermediate_dense": {"bias": 0.006436260882765055, "kernel": 0.1448836326599121}, "output_dense": {"bias": 0.002822649199515581, "kernel": 0.08330518007278442}}, "final_layer_norm": {"bias": 0.013203151524066925, "scale": 0.014038017019629478}, "layer_norm": {"bias": 0.014050491154193878, "scale": 0.0185261033475399}}, "35": {"attention": {"k_proj": {"bias": 6.918576400494203e-05, "kernel": 0.10062655806541443}, "out_proj": {"bias": 0.002610743511468172, "kernel": 0.09429289400577545}, "q_proj": {"bias": 0.004383553750813007, "kernel": 0.09823683649301529}, "v_proj": {"bias": 0.004896542057394981, "kernel": 0.0966302752494812}}, "feed_forward": {"intermediate_dense": {"bias": 0.0051103802397847176, "kernel": 0.11588157713413239}, 
"output_dense": {"bias": 0.002347193658351898, "kernel": 0.06901519000530243}}, "final_layer_norm": {"bias": 0.010162574239075184, "scale": 0.01065899059176445}, "layer_norm": {"bias": 0.011963584460318089, "scale": 0.016352832317352295}}, "36": {"attention": {"k_proj": {"bias": 6.899447180330753e-05, "kernel": 0.0767166018486023}, "out_proj": {"bias": 0.0021734011825174093, "kernel": 0.06424261629581451}, "q_proj": {"bias": 0.0030707651749253273, "kernel": 0.06988008320331573}, "v_proj": {"bias": 0.003990429453551769, "kernel": 0.0750475525856018}}, "feed_forward": {"intermediate_dense": {"bias": 0.004226801451295614, "kernel": 0.09373506158590317}, "output_dense": {"bias": 0.001964298076927662, "kernel": 0.050748247653245926}}, "final_layer_norm": {"bias": 0.008214261382818222, "scale": 0.008091378957033157}, "layer_norm": {"bias": 0.009961582720279694, "scale": 0.014091897755861282}}, "37": {"attention": {"k_proj": {"bias": 6.950553506612778e-05, "kernel": 0.06806355714797974}, "out_proj": {"bias": 0.0018638172186911106, "kernel": 0.06148315966129303}, "q_proj": {"bias": 0.0027856058441102505, "kernel": 0.06524664908647537}, "v_proj": {"bias": 0.003444777335971594, "kernel": 0.07059472799301147}}, "feed_forward": {"intermediate_dense": {"bias": 0.003718752646818757, "kernel": 0.08477935194969177}, "output_dense": {"bias": 0.0016936655156314373, "kernel": 0.043044134974479675}}, "final_layer_norm": {"bias": 0.006910389289259911, "scale": 0.007826968096196651}, "layer_norm": {"bias": 0.008793966844677925, "scale": 0.012039574794471264}}, "38": {"attention": {"k_proj": {"bias": 6.380058766808361e-05, "kernel": 0.047973841428756714}, "out_proj": {"bias": 0.0016418255399912596, "kernel": 0.05211370810866356}, "q_proj": {"bias": 0.0019817340653389692, "kernel": 0.04425235092639923}, "v_proj": {"bias": 0.003118900116533041, "kernel": 0.06387682259082794}}, "feed_forward": {"intermediate_dense": {"bias": 0.0032114291097968817, "kernel": 0.07306227833032608}, "output_dense": {"bias": 0.0014597884146496654, "kernel": 0.040683649480342865}}, "final_layer_norm": {"bias": 0.006397427059710026, "scale": 0.006586809176951647}, "layer_norm": {"bias": 0.007732203230261803, "scale": 0.008158894255757332}}, "39": {"attention": {"k_proj": {"bias": 4.4050942960893735e-05, "kernel": 0.07367727905511856}, "out_proj": {"bias": 0.001328070997260511, "kernel": 0.04605572298169136}, "q_proj": {"bias": 0.0027621600311249495, "kernel": 0.06387172639369965}, "v_proj": {"bias": 0.0028065002989023924, "kernel": 0.06029339134693146}}, "feed_forward": {"intermediate_dense": {"bias": 0.002613849239423871, "kernel": 0.06034219264984131}, "output_dense": {"bias": 0.001167704933322966, "kernel": 0.03533755987882614}}, "final_layer_norm": {"bias": 0.0056147826835513115, "scale": 0.006254551466554403}, "layer_norm": {"bias": 0.008078474551439285, "scale": 0.009307093918323517}}, "4": {"attention": {"k_proj": {"bias": 0.00011861868551932275, "kernel": 0.05949955806136131}, "out_proj": {"bias": 0.010903699323534966, "kernel": 0.2538127303123474}, "q_proj": {"bias": 0.004336205311119556, "kernel": 0.06282104551792145}, "v_proj": {"bias": 0.016858994960784912, "kernel": 0.2015695869922638}}, "feed_forward": {"intermediate_dense": {"bias": 0.014640497043728828, "kernel": 0.22741298377513885}, "output_dense": {"bias": 0.010102435946464539, "kernel": 0.18625867366790771}}, "final_layer_norm": {"bias": 0.025719614699482918, "scale": 0.02265121601521969}, "layer_norm": {"bias": 0.03011525422334671, "scale": 0.02634708769619465}}, 
"40": {"attention": {"k_proj": {"bias": 3.653847670648247e-05, "kernel": 0.041202448308467865}, "out_proj": {"bias": 0.0011816287878900766, "kernel": 0.033319562673568726}, "q_proj": {"bias": 0.001703648129478097, "kernel": 0.03998265415430069}, "v_proj": {"bias": 0.0020524682477116585, "kernel": 0.03995149955153465}}, "feed_forward": {"intermediate_dense": {"bias": 0.0021124468185007572, "kernel": 0.04631158709526062}, "output_dense": {"bias": 0.0010702034924179316, "kernel": 0.029482468962669373}}, "final_layer_norm": {"bias": 0.004548347555100918, "scale": 0.00727574247866869}, "layer_norm": {"bias": 0.005058923736214638, "scale": 0.006590208038687706}}, "41": {"attention": {"k_proj": {"bias": 2.71731387329055e-05, "kernel": 0.04248828813433647}, "out_proj": {"bias": 0.0009862207807600498, "kernel": 0.04138129949569702}, "q_proj": {"bias": 0.0017071680631488562, "kernel": 0.04267487674951553}, "v_proj": {"bias": 0.002120822202414274, "kernel": 0.049663443118333817}}, "feed_forward": {"intermediate_dense": {"bias": 0.001695435494184494, "kernel": 0.04280196502804756}, "output_dense": {"bias": 0.0008836397901177406, "kernel": 0.030698874965310097}}, "final_layer_norm": {"bias": 0.003696385771036148, "scale": 0.00520673394203186}, "layer_norm": {"bias": 0.005573541857302189, "scale": 0.009067263454198837}}, "42": {"attention": {"k_proj": {"bias": 9.996898370445706e-06, "kernel": 0.01619807258248329}, "out_proj": {"bias": 0.000877649406902492, "kernel": 0.024295775219798088}, "q_proj": {"bias": 0.0007479889900423586, "kernel": 0.018460828810930252}, "v_proj": {"bias": 0.0012687454000115395, "kernel": 0.024550896137952805}}, "feed_forward": {"intermediate_dense": {"bias": 0.001578548108227551, "kernel": 0.04231848567724228}, "output_dense": {"bias": 0.0007989116129465401, "kernel": 0.029860224574804306}}, "final_layer_norm": {"bias": 0.0033415176440030336, "scale": 0.004648563452064991}, "layer_norm": {"bias": 0.002819265704602003, "scale": 0.0038991873152554035}}, "43": {"attention": {"k_proj": {"bias": 7.381008799711708e-06, "kernel": 0.009741116315126419}, "out_proj": {"bias": 0.0008183694444596767, "kernel": 0.018541457131505013}, "q_proj": {"bias": 0.0005745080416090786, "kernel": 0.012806322425603867}, "v_proj": {"bias": 0.001041700248606503, "kernel": 0.01789342798292637}}, "feed_forward": {"intermediate_dense": {"bias": 0.0015668238047510386, "kernel": 0.04644262045621872}, "output_dense": {"bias": 0.0007441275520250201, "kernel": 0.03125281631946564}}, "final_layer_norm": {"bias": 0.0034777228720486164, "scale": 0.0037668701261281967}, "layer_norm": {"bias": 0.0023245071060955524, "scale": 0.0032773285638540983}}, "44": {"attention": {"k_proj": {"bias": 6.937757007108303e-06, "kernel": 0.010700032114982605}, "out_proj": {"bias": 0.0007551060989499092, "kernel": 0.018443375825881958}, "q_proj": {"bias": 0.0008867518627084792, "kernel": 0.020168982446193695}, "v_proj": {"bias": 0.0010132507886737585, "kernel": 0.018830671906471252}}, "feed_forward": {"intermediate_dense": {"bias": 0.001374373328872025, "kernel": 0.044485267251729965}, "output_dense": {"bias": 0.0006765900179743767, "kernel": 0.03028588742017746}}, "final_layer_norm": {"bias": 0.003262903541326523, "scale": 0.00358769204467535}, "layer_norm": {"bias": 0.002687399508431554, "scale": 0.0032383990474045277}}, "45": {"attention": {"k_proj": {"bias": 1.1093410648754798e-05, "kernel": 0.005379953421652317}, "out_proj": {"bias": 0.0007007494568824768, "kernel": 0.018904201686382294}, "q_proj": {"bias": 0.0003387936158105731, 
"kernel": 0.007548992987722158}, "v_proj": {"bias": 0.000928386056330055, "kernel": 0.017278626561164856}}, "feed_forward": {"intermediate_dense": {"bias": 0.0012730576563626528, "kernel": 0.0406234934926033}, "output_dense": {"bias": 0.0006063497858121991, "kernel": 0.03363603353500366}}, "final_layer_norm": {"bias": 0.0036317934282124043, "scale": 0.0037220551166683435}, "layer_norm": {"bias": 0.0023685358464717865, "scale": 0.0020556054078042507}}, "46": {"attention": {"k_proj": {"bias": 3.145817754557356e-05, "kernel": 0.003510159905999899}, "out_proj": {"bias": 0.00062037433963269, "kernel": 0.021085944026708603}, "q_proj": {"bias": 0.0002703170757740736, "kernel": 0.005593193229287863}, "v_proj": {"bias": 0.0007606031140312552, "kernel": 0.01523013599216938}}, "feed_forward": {"intermediate_dense": {"bias": 0.0010667876340448856, "kernel": 0.028970055282115936}, "output_dense": {"bias": 0.0006037228740751743, "kernel": 0.03109053149819374}}, "final_layer_norm": {"bias": 0.0034019681625068188, "scale": 0.0032971068285405636}, "layer_norm": {"bias": 0.0022144035901874304, "scale": 0.0021386942826211452}}, "47": {"attention": {"k_proj": {"bias": 7.572821050416678e-05, "kernel": 0.004096582997590303}, "out_proj": {"bias": 0.000618669087998569, "kernel": 0.06384488940238953}, "q_proj": {"bias": 0.00016807969950605184, "kernel": 0.0034361861180514097}, "v_proj": {"bias": 0.0008409329457208514, "kernel": 0.01652412861585617}}, "feed_forward": {"intermediate_dense": {"bias": 0.000792563718277961, "kernel": 0.017461199313402176}, "output_dense": {"bias": 0.0005989559576846659, "kernel": 0.09638004750013351}}, "final_layer_norm": {"bias": 0.0062134405598044395, "scale": 0.006810440681874752}, "layer_norm": {"bias": 0.004194959066808224, "scale": 0.004303786437958479}}, "5": {"attention": {"k_proj": {"bias": 2.6883770260610618e-05, "kernel": 0.0624186210334301}, "out_proj": {"bias": 0.011139899492263794, "kernel": 0.13817574083805084}, "q_proj": {"bias": 0.004674529191106558, "kernel": 0.06611215323209763}, "v_proj": {"bias": 0.017080599442124367, "kernel": 0.12343898415565491}}, "feed_forward": {"intermediate_dense": {"bias": 0.014788923785090446, "kernel": 0.2172151654958725}, "output_dense": {"bias": 0.010435854084789753, "kernel": 0.18610242009162903}}, "final_layer_norm": {"bias": 0.02695547789335251, "scale": 0.023944713175296783}, "layer_norm": {"bias": 0.033194735646247864, "scale": 0.054335542023181915}}, "6": {"attention": {"k_proj": {"bias": 5.946035889792256e-05, "kernel": 0.0887044221162796}, "out_proj": {"bias": 0.010512596927583218, "kernel": 0.23351940512657166}, "q_proj": {"bias": 0.006024146918207407, "kernel": 0.08463111519813538}, "v_proj": {"bias": 0.016660811379551888, "kernel": 0.20662474632263184}}, "feed_forward": {"intermediate_dense": {"bias": 0.015289617702364922, "kernel": 0.2351919412612915}, "output_dense": {"bias": 0.009930520318448544, "kernel": 0.18425649404525757}}, "final_layer_norm": {"bias": 0.025580020621418953, "scale": 0.02367497608065605}, "layer_norm": {"bias": 0.03188800439238548, "scale": 0.03837102651596069}}, "7": {"attention": {"k_proj": {"bias": 0.00011192046076757833, "kernel": 0.09136287868022919}, "out_proj": {"bias": 0.010412187315523624, "kernel": 0.2562747895717621}, "q_proj": {"bias": 0.005691881757229567, "kernel": 0.08856040239334106}, "v_proj": {"bias": 0.017323974519968033, "kernel": 0.23546917736530304}}, "feed_forward": {"intermediate_dense": {"bias": 0.015341667458415031, "kernel": 0.245836079120636}, "output_dense": {"bias": 
0.009862898848950863, "kernel": 0.19427022337913513}}, "final_layer_norm": {"bias": 0.027472414076328278, "scale": 0.022995298728346825}, "layer_norm": {"bias": 0.03300042822957039, "scale": 0.037924688309431076}}, "8": {"attention": {"k_proj": {"bias": 9.741022950038314e-05, "kernel": 0.09417441487312317}, "out_proj": {"bias": 0.01020738109946251, "kernel": 0.21754060685634613}, "q_proj": {"bias": 0.007090041413903236, "kernel": 0.09902448952198029}, "v_proj": {"bias": 0.016644656658172607, "kernel": 0.2003098428249359}}, "feed_forward": {"intermediate_dense": {"bias": 0.015562525019049644, "kernel": 0.25062885880470276}, "output_dense": {"bias": 0.009757128544151783, "kernel": 0.19549694657325745}}, "final_layer_norm": {"bias": 0.027269396930933, "scale": 0.028064392507076263}, "layer_norm": {"bias": 0.03486243635416031, "scale": 0.03840157017111778}}, "9": {"attention": {"k_proj": {"bias": 0.00010006569209508598, "kernel": 0.1382577121257782}, "out_proj": {"bias": 0.009580248035490513, "kernel": 0.308486670255661}, "q_proj": {"bias": 0.008061045780777931, "kernel": 0.12973228096961975}, "v_proj": {"bias": 0.015853915363550186, "kernel": 0.29045191407203674}}, "feed_forward": {"intermediate_dense": {"bias": 0.014156410470604897, "kernel": 0.2397565096616745}, "output_dense": {"bias": 0.009250838309526443, "kernel": 0.18961884081363678}}, "final_layer_norm": {"bias": 0.026453740894794464, "scale": 0.024667374789714813}, "layer_norm": {"bias": 0.031355082988739014, "scale": 0.03998243808746338}}}, "pos_conv_embed": {"conv": {"bias": 0.022986222058534622, "weight_g": 0.04244992509484291, "weight_v": 0.25040704011917114}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "1": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "2": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "3": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "4": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "5": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "6": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}}}, "feature_projection": {"layer_norm": {"bias": 0.07150723785161972, "scale": 0.11794900894165039}, "projection": {"bias": 0.03498292341828346, "kernel": 0.76469486951828}}, "masked_spec_embed": 0.0}}, "layer_param_norm/": {"lm_head": {"bias": 0.7813382744789124, "kernel": 55.68170166015625}, "wav2vec2": {"encoder": {"layer_norm": {"bias": 59.61796569824219, "scale": 74.11318969726562}, "layers": {"0": {"attention": {"k_proj": {"bias": 0.3701832890510559, "kernel": 27.534406661987305}, "out_proj": {"bias": 1.6470978260040283, "kernel": 26.14641761779785}, "q_proj": {"bias": 1.5328073501586914, "kernel": 27.812232971191406}, "v_proj": {"bias": 0.44818246364593506, "kernel": 26.558326721191406}}, "feed_forward": {"intermediate_dense": {"bias": 1.98408842086792, "kernel": 100.66358947753906}, "output_dense": {"bias": 1.1168419122695923, "kernel": 96.7653579711914}}, "final_layer_norm": {"bias": 1.33500337600708, "scale": 19.857891082763672}, "layer_norm": {"bias": 2.9222915172576904, "scale": 15.399209976196289}}, "1": {"attention": {"k_proj": {"bias": 0.37601765990257263, "kernel": 41.01239013671875}, "out_proj": {"bias": 1.365553617477417, "kernel": 43.370052337646484}, "q_proj": {"bias": 3.0928053855895996, "kernel": 
41.055511474609375}, "v_proj": {"bias": 0.29229462146759033, "kernel": 41.610923767089844}}, "feed_forward": {"intermediate_dense": {"bias": 1.988027811050415, "kernel": 98.74156951904297}, "output_dense": {"bias": 0.8528561592102051, "kernel": 87.83381652832031}}, "final_layer_norm": {"bias": 1.3606328964233398, "scale": 19.08454132080078}, "layer_norm": {"bias": 1.932058334350586, "scale": 17.76154327392578}}, "10": {"attention": {"k_proj": {"bias": 0.412264347076416, "kernel": 49.44655227661133}, "out_proj": {"bias": 1.312946081161499, "kernel": 52.270050048828125}, "q_proj": {"bias": 2.484537124633789, "kernel": 49.528587341308594}, "v_proj": {"bias": 0.3413662910461426, "kernel": 52.343936920166016}}, "feed_forward": {"intermediate_dense": {"bias": 1.9745526313781738, "kernel": 102.7016830444336}, "output_dense": {"bias": 0.5953460931777954, "kernel": 95.81098937988281}}, "final_layer_norm": {"bias": 2.376117706298828, "scale": 20.812576293945312}, "layer_norm": {"bias": 1.806325078010559, "scale": 21.430335998535156}}, "11": {"attention": {"k_proj": {"bias": 0.4463092088699341, "kernel": 49.3727912902832}, "out_proj": {"bias": 1.150918960571289, "kernel": 51.94880294799805}, "q_proj": {"bias": 2.54514741897583, "kernel": 49.2027587890625}, "v_proj": {"bias": 0.40985018014907837, "kernel": 52.26004409790039}}, "feed_forward": {"intermediate_dense": {"bias": 2.0194969177246094, "kernel": 103.55790710449219}, "output_dense": {"bias": 0.5710991621017456, "kernel": 97.53599548339844}}, "final_layer_norm": {"bias": 2.365635871887207, "scale": 20.91950225830078}, "layer_norm": {"bias": 1.7795329093933105, "scale": 22.015670776367188}}, "12": {"attention": {"k_proj": {"bias": 0.4312525689601898, "kernel": 50.0694580078125}, "out_proj": {"bias": 1.1258351802825928, "kernel": 51.988014221191406}, "q_proj": {"bias": 2.409198760986328, "kernel": 49.926300048828125}, "v_proj": {"bias": 0.4022805690765381, "kernel": 52.2958984375}}, "feed_forward": {"intermediate_dense": {"bias": 2.0547420978546143, "kernel": 104.54605865478516}, "output_dense": {"bias": 0.5548651218414307, "kernel": 99.26763916015625}}, "final_layer_norm": {"bias": 2.2928576469421387, "scale": 20.85626220703125}, "layer_norm": {"bias": 1.8584762811660767, "scale": 22.473596572875977}}, "13": {"attention": {"k_proj": {"bias": 0.44326990842819214, "kernel": 51.76347351074219}, "out_proj": {"bias": 1.1270570755004883, "kernel": 51.86771774291992}, "q_proj": {"bias": 2.3591299057006836, "kernel": 51.751495361328125}, "v_proj": {"bias": 0.39057764410972595, "kernel": 51.926971435546875}}, "feed_forward": {"intermediate_dense": {"bias": 2.09421968460083, "kernel": 105.3058090209961}, "output_dense": {"bias": 0.5719174742698669, "kernel": 99.86944580078125}}, "final_layer_norm": {"bias": 2.231492519378662, "scale": 21.027198791503906}, "layer_norm": {"bias": 1.9996953010559082, "scale": 22.845252990722656}}, "14": {"attention": {"k_proj": {"bias": 0.43609967827796936, "kernel": 51.92134475708008}, "out_proj": {"bias": 1.2679402828216553, "kernel": 49.76212692260742}, "q_proj": {"bias": 2.494476795196533, "kernel": 52.04914855957031}, "v_proj": {"bias": 0.3679965138435364, "kernel": 49.262184143066406}}, "feed_forward": {"intermediate_dense": {"bias": 2.1335108280181885, "kernel": 105.94230651855469}, "output_dense": {"bias": 0.6064561009407043, "kernel": 101.24674987792969}}, "final_layer_norm": {"bias": 2.2754456996917725, "scale": 21.145736694335938}, "layer_norm": {"bias": 2.1294503211975098, "scale": 22.58496856689453}}, "15": 
{"attention": {"k_proj": {"bias": 0.4591600298881531, "kernel": 51.9299430847168}, "out_proj": {"bias": 1.3749371767044067, "kernel": 50.94367218017578}, "q_proj": {"bias": 2.587465763092041, "kernel": 52.09908676147461}, "v_proj": {"bias": 0.45972001552581787, "kernel": 50.582244873046875}}, "feed_forward": {"intermediate_dense": {"bias": 2.132887840270996, "kernel": 105.57237243652344}, "output_dense": {"bias": 0.7701091170310974, "kernel": 101.93905639648438}}, "final_layer_norm": {"bias": 2.326786518096924, "scale": 21.19296646118164}, "layer_norm": {"bias": 2.4140264987945557, "scale": 23.52659034729004}}, "16": {"attention": {"k_proj": {"bias": 0.40078312158584595, "kernel": 51.77173614501953}, "out_proj": {"bias": 1.2750227451324463, "kernel": 50.133506774902344}, "q_proj": {"bias": 2.6670660972595215, "kernel": 51.7573356628418}, "v_proj": {"bias": 0.37677696347236633, "kernel": 49.779869079589844}}, "feed_forward": {"intermediate_dense": {"bias": 2.1093995571136475, "kernel": 106.10017395019531}, "output_dense": {"bias": 0.7859854698181152, "kernel": 102.66400146484375}}, "final_layer_norm": {"bias": 2.3371658325195312, "scale": 21.5831298828125}, "layer_norm": {"bias": 2.2830886840820312, "scale": 22.16848373413086}}, "17": {"attention": {"k_proj": {"bias": 0.39663973450660706, "kernel": 51.728233337402344}, "out_proj": {"bias": 1.2150659561157227, "kernel": 49.45488739013672}, "q_proj": {"bias": 2.7145652770996094, "kernel": 51.818031311035156}, "v_proj": {"bias": 0.42649292945861816, "kernel": 49.118465423583984}}, "feed_forward": {"intermediate_dense": {"bias": 2.1102144718170166, "kernel": 107.14662170410156}, "output_dense": {"bias": 0.821881890296936, "kernel": 103.0622787475586}}, "final_layer_norm": {"bias": 2.3840279579162598, "scale": 22.070194244384766}, "layer_norm": {"bias": 2.222987651824951, "scale": 21.220136642456055}}, "18": {"attention": {"k_proj": {"bias": 0.44067007303237915, "kernel": 52.415435791015625}, "out_proj": {"bias": 1.3447682857513428, "kernel": 50.4912109375}, "q_proj": {"bias": 2.6146581172943115, "kernel": 52.795860290527344}, "v_proj": {"bias": 0.45181968808174133, "kernel": 50.00122833251953}}, "feed_forward": {"intermediate_dense": {"bias": 2.1441776752471924, "kernel": 107.41145324707031}, "output_dense": {"bias": 0.9481906890869141, "kernel": 104.7232666015625}}, "final_layer_norm": {"bias": 2.5388097763061523, "scale": 22.15188980102539}, "layer_norm": {"bias": 2.424826145172119, "scale": 23.585994720458984}}, "19": {"attention": {"k_proj": {"bias": 0.3818454146385193, "kernel": 51.51025390625}, "out_proj": {"bias": 1.3302791118621826, "kernel": 50.099727630615234}, "q_proj": {"bias": 2.9301462173461914, "kernel": 51.864768981933594}, "v_proj": {"bias": 0.408627450466156, "kernel": 49.380149841308594}}, "feed_forward": {"intermediate_dense": {"bias": 2.191500186920166, "kernel": 107.95045471191406}, "output_dense": {"bias": 1.024780511856079, "kernel": 105.64892578125}}, "final_layer_norm": {"bias": 2.4928956031799316, "scale": 22.50541877746582}, "layer_norm": {"bias": 2.288670539855957, "scale": 22.31869125366211}}, "2": {"attention": {"k_proj": {"bias": 0.4549238085746765, "kernel": 47.772315979003906}, "out_proj": {"bias": 1.2571773529052734, "kernel": 45.969146728515625}, "q_proj": {"bias": 3.2510793209075928, "kernel": 47.61619567871094}, "v_proj": {"bias": 0.3394175171852112, "kernel": 45.72221374511719}}, "feed_forward": {"intermediate_dense": {"bias": 1.9738647937774658, "kernel": 103.3251953125}, "output_dense": {"bias": 
0.7398295998573303, "kernel": 91.11064147949219}}, "final_layer_norm": {"bias": 1.5427566766738892, "scale": 21.560688018798828}, "layer_norm": {"bias": 1.7084418535232544, "scale": 20.85301971435547}}, "20": {"attention": {"k_proj": {"bias": 0.4066943824291229, "kernel": 51.604251861572266}, "out_proj": {"bias": 1.3597817420959473, "kernel": 49.454742431640625}, "q_proj": {"bias": 2.849797010421753, "kernel": 52.22339630126953}, "v_proj": {"bias": 0.3622246980667114, "kernel": 48.43778991699219}}, "feed_forward": {"intermediate_dense": {"bias": 2.1723599433898926, "kernel": 109.17230224609375}, "output_dense": {"bias": 1.1388349533081055, "kernel": 106.40357971191406}}, "final_layer_norm": {"bias": 2.4343795776367188, "scale": 23.432008743286133}, "layer_norm": {"bias": 2.231492042541504, "scale": 22.23037338256836}}, "21": {"attention": {"k_proj": {"bias": 0.4163251221179962, "kernel": 51.941925048828125}, "out_proj": {"bias": 1.4036574363708496, "kernel": 49.50983428955078}, "q_proj": {"bias": 2.7688515186309814, "kernel": 52.67023849487305}, "v_proj": {"bias": 0.41047605872154236, "kernel": 48.6480712890625}}, "feed_forward": {"intermediate_dense": {"bias": 2.217552423477173, "kernel": 109.51356506347656}, "output_dense": {"bias": 1.2531100511550903, "kernel": 106.88032531738281}}, "final_layer_norm": {"bias": 2.463512420654297, "scale": 23.17563247680664}, "layer_norm": {"bias": 2.2788496017456055, "scale": 22.23470687866211}}, "22": {"attention": {"k_proj": {"bias": 0.45348384976387024, "kernel": 52.54527282714844}, "out_proj": {"bias": 1.3491016626358032, "kernel": 49.532562255859375}, "q_proj": {"bias": 2.810800313949585, "kernel": 52.869224548339844}, "v_proj": {"bias": 0.3972615599632263, "kernel": 49.3331298828125}}, "feed_forward": {"intermediate_dense": {"bias": 2.1617891788482666, "kernel": 109.95318603515625}, "output_dense": {"bias": 1.3074685335159302, "kernel": 106.38297271728516}}, "final_layer_norm": {"bias": 2.363884449005127, "scale": 22.684080123901367}, "layer_norm": {"bias": 2.3313965797424316, "scale": 21.546348571777344}}, "23": {"attention": {"k_proj": {"bias": 0.4928063750267029, "kernel": 53.47571563720703}, "out_proj": {"bias": 1.5644241571426392, "kernel": 50.986122131347656}, "q_proj": {"bias": 2.7066195011138916, "kernel": 53.581642150878906}, "v_proj": {"bias": 0.5814877152442932, "kernel": 51.54754638671875}}, "feed_forward": {"intermediate_dense": {"bias": 2.132040023803711, "kernel": 109.86238098144531}, "output_dense": {"bias": 1.2768850326538086, "kernel": 107.37675476074219}}, "final_layer_norm": {"bias": 2.768141269683838, "scale": 22.887840270996094}, "layer_norm": {"bias": 2.824951171875, "scale": 23.372961044311523}}, "24": {"attention": {"k_proj": {"bias": 0.46051931381225586, "kernel": 52.42326354980469}, "out_proj": {"bias": 1.607046365737915, "kernel": 52.502601623535156}, "q_proj": {"bias": 2.8280816078186035, "kernel": 52.404296875}, "v_proj": {"bias": 0.5424301624298096, "kernel": 52.510433197021484}}, "feed_forward": {"intermediate_dense": {"bias": 2.2367987632751465, "kernel": 109.34848022460938}, "output_dense": {"bias": 1.30157470703125, "kernel": 110.298828125}}, "final_layer_norm": {"bias": 2.8384151458740234, "scale": 22.96453857421875}, "layer_norm": {"bias": 2.5624277591705322, "scale": 22.984024047851562}}, "25": {"attention": {"k_proj": {"bias": 0.4248366355895996, "kernel": 52.728973388671875}, "out_proj": {"bias": 1.363987922668457, "kernel": 50.57943344116211}, "q_proj": {"bias": 2.9338529109954834, "kernel": 
52.54735565185547}, "v_proj": {"bias": 0.6403459310531616, "kernel": 51.08721923828125}}, "feed_forward": {"intermediate_dense": {"bias": 2.1368496417999268, "kernel": 109.69841003417969}, "output_dense": {"bias": 1.1018216609954834, "kernel": 110.26852416992188}}, "final_layer_norm": {"bias": 2.5759782791137695, "scale": 23.49481201171875}, "layer_norm": {"bias": 2.6831774711608887, "scale": 21.883270263671875}}, "26": {"attention": {"k_proj": {"bias": 0.4836840033531189, "kernel": 53.01679992675781}, "out_proj": {"bias": 1.2437195777893066, "kernel": 51.36989212036133}, "q_proj": {"bias": 2.9436087608337402, "kernel": 52.80816650390625}, "v_proj": {"bias": 0.5066676139831543, "kernel": 52.00371551513672}}, "feed_forward": {"intermediate_dense": {"bias": 2.2764625549316406, "kernel": 109.44453430175781}, "output_dense": {"bias": 1.0910954475402832, "kernel": 107.40631866455078}}, "final_layer_norm": {"bias": 2.1941800117492676, "scale": 22.433135986328125}, "layer_norm": {"bias": 2.4980502128601074, "scale": 22.190664291381836}}, "27": {"attention": {"k_proj": {"bias": 0.5808664560317993, "kernel": 53.7679443359375}, "out_proj": {"bias": 1.5447583198547363, "kernel": 52.95677185058594}, "q_proj": {"bias": 2.703429698944092, "kernel": 53.694732666015625}, "v_proj": {"bias": 0.6754124164581299, "kernel": 53.38676834106445}}, "feed_forward": {"intermediate_dense": {"bias": 2.4051074981689453, "kernel": 107.8687744140625}, "output_dense": {"bias": 0.9484013319015503, "kernel": 107.16940307617188}}, "final_layer_norm": {"bias": 2.5260066986083984, "scale": 21.88892364501953}, "layer_norm": {"bias": 2.614506483078003, "scale": 23.32402992248535}}, "28": {"attention": {"k_proj": {"bias": 0.5901625752449036, "kernel": 54.48159408569336}, "out_proj": {"bias": 1.5366472005844116, "kernel": 53.31403350830078}, "q_proj": {"bias": 2.9472250938415527, "kernel": 54.17335510253906}, "v_proj": {"bias": 0.5132970809936523, "kernel": 53.758819580078125}}, "feed_forward": {"intermediate_dense": {"bias": 2.347456455230713, "kernel": 107.8719482421875}, "output_dense": {"bias": 0.8225339651107788, "kernel": 109.16556549072266}}, "final_layer_norm": {"bias": 2.4250011444091797, "scale": 22.337326049804688}, "layer_norm": {"bias": 2.0917792320251465, "scale": 23.993865966796875}}, "29": {"attention": {"k_proj": {"bias": 0.4677562713623047, "kernel": 51.1196174621582}, "out_proj": {"bias": 1.5024232864379883, "kernel": 55.68492889404297}, "q_proj": {"bias": 2.80938982963562, "kernel": 51.00201416015625}, "v_proj": {"bias": 0.47646182775497437, "kernel": 55.70256805419922}}, "feed_forward": {"intermediate_dense": {"bias": 2.2973735332489014, "kernel": 108.0084228515625}, "output_dense": {"bias": 0.9599496722221375, "kernel": 113.12579345703125}}, "final_layer_norm": {"bias": 2.5981667041778564, "scale": 23.459945678710938}, "layer_norm": {"bias": 2.2456541061401367, "scale": 25.399646759033203}}, "3": {"attention": {"k_proj": {"bias": 0.45023608207702637, "kernel": 52.03062057495117}, "out_proj": {"bias": 1.4256000518798828, "kernel": 48.607704162597656}, "q_proj": {"bias": 2.8558900356292725, "kernel": 52.31114959716797}, "v_proj": {"bias": 0.3248745799064636, "kernel": 48.767791748046875}}, "feed_forward": {"intermediate_dense": {"bias": 1.966486930847168, "kernel": 104.83381652832031}, "output_dense": {"bias": 0.6985666155815125, "kernel": 94.07772827148438}}, "final_layer_norm": {"bias": 1.80990731716156, "scale": 21.664508819580078}, "layer_norm": {"bias": 1.9018754959106445, "scale": 22.739166259765625}}, 
"30": {"attention": {"k_proj": {"bias": 0.5022295713424683, "kernel": 52.82470703125}, "out_proj": {"bias": 1.302648663520813, "kernel": 52.05290222167969}, "q_proj": {"bias": 2.9069385528564453, "kernel": 52.9173583984375}, "v_proj": {"bias": 0.49356698989868164, "kernel": 52.492820739746094}}, "feed_forward": {"intermediate_dense": {"bias": 2.240218162536621, "kernel": 108.17642974853516}, "output_dense": {"bias": 0.9139082431793213, "kernel": 112.08816528320312}}, "final_layer_norm": {"bias": 2.4933857917785645, "scale": 24.492006301879883}, "layer_norm": {"bias": 2.3173294067382812, "scale": 24.931093215942383}}, "31": {"attention": {"k_proj": {"bias": 0.5415275692939758, "kernel": 51.23994064331055}, "out_proj": {"bias": 1.233104944229126, "kernel": 52.19916534423828}, "q_proj": {"bias": 2.658158779144287, "kernel": 51.345314025878906}, "v_proj": {"bias": 0.5471001863479614, "kernel": 52.43183517456055}}, "feed_forward": {"intermediate_dense": {"bias": 2.309779644012451, "kernel": 106.72549438476562}, "output_dense": {"bias": 1.0889759063720703, "kernel": 109.24472045898438}}, "final_layer_norm": {"bias": 2.296581268310547, "scale": 24.312225341796875}, "layer_norm": {"bias": 2.3432631492614746, "scale": 24.59036636352539}}, "32": {"attention": {"k_proj": {"bias": 0.470304012298584, "kernel": 50.39829635620117}, "out_proj": {"bias": 1.2450408935546875, "kernel": 51.573692321777344}, "q_proj": {"bias": 2.845055103302002, "kernel": 50.347450256347656}, "v_proj": {"bias": 0.4195507764816284, "kernel": 51.971351623535156}}, "feed_forward": {"intermediate_dense": {"bias": 2.256967544555664, "kernel": 105.32807922363281}, "output_dense": {"bias": 1.1466200351715088, "kernel": 108.35316467285156}}, "final_layer_norm": {"bias": 2.315363883972168, "scale": 24.518821716308594}, "layer_norm": {"bias": 2.417573928833008, "scale": 24.991804122924805}}, "33": {"attention": {"k_proj": {"bias": 0.4837779402732849, "kernel": 50.2818603515625}, "out_proj": {"bias": 1.280759334564209, "kernel": 51.304786682128906}, "q_proj": {"bias": 2.9976298809051514, "kernel": 50.253135681152344}, "v_proj": {"bias": 0.4415435791015625, "kernel": 51.70958709716797}}, "feed_forward": {"intermediate_dense": {"bias": 2.279780387878418, "kernel": 103.67518615722656}, "output_dense": {"bias": 1.1753971576690674, "kernel": 106.81249237060547}}, "final_layer_norm": {"bias": 2.2574286460876465, "scale": 24.20797348022461}, "layer_norm": {"bias": 2.587249517440796, "scale": 25.068130493164062}}, "34": {"attention": {"k_proj": {"bias": 0.4537648558616638, "kernel": 49.26386260986328}, "out_proj": {"bias": 1.526426076889038, "kernel": 52.460693359375}, "q_proj": {"bias": 2.9134769439697266, "kernel": 49.267578125}, "v_proj": {"bias": 0.40280866622924805, "kernel": 52.5334358215332}}, "feed_forward": {"intermediate_dense": {"bias": 2.374392032623291, "kernel": 102.2095947265625}, "output_dense": {"bias": 1.1227495670318604, "kernel": 105.71914672851562}}, "final_layer_norm": {"bias": 2.202498435974121, "scale": 23.641082763671875}, "layer_norm": {"bias": 2.616096019744873, "scale": 25.498010635375977}}, "35": {"attention": {"k_proj": {"bias": 0.5340859889984131, "kernel": 51.04802703857422}, "out_proj": {"bias": 1.490316390991211, "kernel": 51.14915084838867}, "q_proj": {"bias": 2.5735840797424316, "kernel": 51.32292175292969}, "v_proj": {"bias": 0.489106148481369, "kernel": 51.175716400146484}}, "feed_forward": {"intermediate_dense": {"bias": 2.5023369789123535, "kernel": 100.7518081665039}, "output_dense": {"bias": 
1.0255975723266602, "kernel": 104.24783325195312}}, "final_layer_norm": {"bias": 2.295048952102661, "scale": 23.61989974975586}, "layer_norm": {"bias": 2.5003116130828857, "scale": 26.09359359741211}}, "36": {"attention": {"k_proj": {"bias": 0.4473268389701843, "kernel": 48.31817626953125}, "out_proj": {"bias": 1.5180878639221191, "kernel": 52.26365661621094}, "q_proj": {"bias": 2.612337112426758, "kernel": 48.23484802246094}, "v_proj": {"bias": 0.3936375379562378, "kernel": 52.66379165649414}}, "feed_forward": {"intermediate_dense": {"bias": 2.366895914077759, "kernel": 99.60334777832031}, "output_dense": {"bias": 1.0262877941131592, "kernel": 103.67841339111328}}, "final_layer_norm": {"bias": 2.044355630874634, "scale": 24.135536193847656}, "layer_norm": {"bias": 2.289370059967041, "scale": 25.638019561767578}}, "37": {"attention": {"k_proj": {"bias": 0.6238871812820435, "kernel": 47.2994270324707}, "out_proj": {"bias": 1.7608067989349365, "kernel": 52.18843460083008}, "q_proj": {"bias": 2.3821022510528564, "kernel": 47.31465530395508}, "v_proj": {"bias": 0.38558459281921387, "kernel": 52.32508087158203}}, "feed_forward": {"intermediate_dense": {"bias": 2.269535779953003, "kernel": 98.56369018554688}, "output_dense": {"bias": 1.0089890956878662, "kernel": 103.12788391113281}}, "final_layer_norm": {"bias": 1.790808916091919, "scale": 24.50014877319336}, "layer_norm": {"bias": 2.2385404109954834, "scale": 25.562536239624023}}, "38": {"attention": {"k_proj": {"bias": 0.7226560115814209, "kernel": 45.4516716003418}, "out_proj": {"bias": 1.4467517137527466, "kernel": 51.49433898925781}, "q_proj": {"bias": 2.262176990509033, "kernel": 45.453369140625}, "v_proj": {"bias": 0.42999938130378723, "kernel": 51.56456756591797}}, "feed_forward": {"intermediate_dense": {"bias": 2.201237678527832, "kernel": 96.4322509765625}, "output_dense": {"bias": 0.9820237159729004, "kernel": 101.30199432373047}}, "final_layer_norm": {"bias": 1.7823377847671509, "scale": 25.21481704711914}, "layer_norm": {"bias": 2.4086427688598633, "scale": 26.425819396972656}}, "39": {"attention": {"k_proj": {"bias": 0.7209489941596985, "kernel": 45.246360778808594}, "out_proj": {"bias": 1.7099244594573975, "kernel": 51.328575134277344}, "q_proj": {"bias": 2.1155166625976562, "kernel": 45.53589630126953}, "v_proj": {"bias": 0.42500513792037964, "kernel": 51.27678680419922}}, "feed_forward": {"intermediate_dense": {"bias": 2.1776342391967773, "kernel": 94.23681640625}, "output_dense": {"bias": 1.046210765838623, "kernel": 101.17521667480469}}, "final_layer_norm": {"bias": 1.7543823719024658, "scale": 25.785675048828125}, "layer_norm": {"bias": 2.331770896911621, "scale": 26.942829132080078}}, "4": {"attention": {"k_proj": {"bias": 0.44543692469596863, "kernel": 54.622467041015625}, "out_proj": {"bias": 1.6522142887115479, "kernel": 50.16417694091797}, "q_proj": {"bias": 2.6155476570129395, "kernel": 54.929893493652344}, "v_proj": {"bias": 0.348906010389328, "kernel": 50.31928634643555}}, "feed_forward": {"intermediate_dense": {"bias": 1.9531652927398682, "kernel": 104.50743103027344}, "output_dense": {"bias": 0.8574683666229248, "kernel": 95.54345703125}}, "final_layer_norm": {"bias": 1.9918160438537598, "scale": 21.200042724609375}, "layer_norm": {"bias": 2.054326295852661, "scale": 23.6204891204834}}, "40": {"attention": {"k_proj": {"bias": 0.6591145396232605, "kernel": 44.19691467285156}, "out_proj": {"bias": 1.6251122951507568, "kernel": 49.55665588378906}, "q_proj": {"bias": 1.968193769454956, "kernel": 44.890804290771484}, 
"v_proj": {"bias": 0.4584900736808777, "kernel": 49.23584747314453}}, "feed_forward": {"intermediate_dense": {"bias": 2.0335679054260254, "kernel": 92.16676330566406}, "output_dense": {"bias": 1.088383436203003, "kernel": 98.40609741210938}}, "final_layer_norm": {"bias": 1.785752296447754, "scale": 25.04298973083496}, "layer_norm": {"bias": 2.2745361328125, "scale": 26.4086856842041}}, "41": {"attention": {"k_proj": {"bias": 1.7136410474777222, "kernel": 41.96782302856445}, "out_proj": {"bias": 1.3790074586868286, "kernel": 51.25641632080078}, "q_proj": {"bias": 1.714700698852539, "kernel": 42.56212615966797}, "v_proj": {"bias": 0.4701315760612488, "kernel": 50.3696403503418}}, "feed_forward": {"intermediate_dense": {"bias": 2.110346555709839, "kernel": 88.92317199707031}, "output_dense": {"bias": 1.1444766521453857, "kernel": 97.37318420410156}}, "final_layer_norm": {"bias": 2.2391064167022705, "scale": 28.50811004638672}, "layer_norm": {"bias": 2.224137783050537, "scale": 28.247453689575195}}, "42": {"attention": {"k_proj": {"bias": 0.8601677417755127, "kernel": 38.31172180175781}, "out_proj": {"bias": 1.4422290325164795, "kernel": 45.07737731933594}, "q_proj": {"bias": 1.5497033596038818, "kernel": 39.52315902709961}, "v_proj": {"bias": 0.6942248344421387, "kernel": 43.49156951904297}}, "feed_forward": {"intermediate_dense": {"bias": 1.909865379333496, "kernel": 88.00902557373047}, "output_dense": {"bias": 1.1991633176803589, "kernel": 95.75923156738281}}, "final_layer_norm": {"bias": 1.922487497329712, "scale": 29.81951904296875}, "layer_norm": {"bias": 1.6743782758712769, "scale": 26.81166648864746}}, "43": {"attention": {"k_proj": {"bias": 1.2469921112060547, "kernel": 34.693359375}, "out_proj": {"bias": 1.4187889099121094, "kernel": 41.365543365478516}, "q_proj": {"bias": 1.377383828163147, "kernel": 35.388580322265625}, "v_proj": {"bias": 0.5821226835250854, "kernel": 39.29367446899414}}, "feed_forward": {"intermediate_dense": {"bias": 1.8935291767120361, "kernel": 87.0740966796875}, "output_dense": {"bias": 0.945226788520813, "kernel": 93.7410888671875}}, "final_layer_norm": {"bias": 2.000356674194336, "scale": 32.051448822021484}, "layer_norm": {"bias": 1.7934319972991943, "scale": 25.132694244384766}}, "44": {"attention": {"k_proj": {"bias": 2.518559455871582, "kernel": 35.16217803955078}, "out_proj": {"bias": 1.1713123321533203, "kernel": 45.02171325683594}, "q_proj": {"bias": 1.3182837963104248, "kernel": 35.58189010620117}, "v_proj": {"bias": 0.3941196799278259, "kernel": 44.14886474609375}}, "feed_forward": {"intermediate_dense": {"bias": 1.9490540027618408, "kernel": 86.07324981689453}, "output_dense": {"bias": 0.8624283075332642, "kernel": 91.57963562011719}}, "final_layer_norm": {"bias": 2.0508077144622803, "scale": 34.28636932373047}, "layer_norm": {"bias": 1.681735634803772, "scale": 25.142786026000977}}, "45": {"attention": {"k_proj": {"bias": 2.081468343734741, "kernel": 34.86262512207031}, "out_proj": {"bias": 1.0405850410461426, "kernel": 48.60301208496094}, "q_proj": {"bias": 1.4064133167266846, "kernel": 35.032684326171875}, "v_proj": {"bias": 0.4274117350578308, "kernel": 48.770912170410156}}, "feed_forward": {"intermediate_dense": {"bias": 2.0166711807250977, "kernel": 82.92997741699219}, "output_dense": {"bias": 0.9761476516723633, "kernel": 87.24011993408203}}, "final_layer_norm": {"bias": 1.9100477695465088, "scale": 33.14140319824219}, "layer_norm": {"bias": 1.567474126815796, "scale": 23.78523826599121}}, "46": {"attention": {"k_proj": {"bias": 
1.5659271478652954, "kernel": 35.87921142578125}, "out_proj": {"bias": 0.819352388381958, "kernel": 51.16535949707031}, "q_proj": {"bias": 1.563781499862671, "kernel": 36.190025329589844}, "v_proj": {"bias": 0.4101935625076294, "kernel": 51.894287109375}}, "feed_forward": {"intermediate_dense": {"bias": 2.012660026550293, "kernel": 77.47468566894531}, "output_dense": {"bias": 1.1406304836273193, "kernel": 77.72789764404297}}, "final_layer_norm": {"bias": 1.807024359703064, "scale": 28.70673370361328}, "layer_norm": {"bias": 1.400099515914917, "scale": 22.81011962890625}}, "47": {"attention": {"k_proj": {"bias": 0.6174112558364868, "kernel": 38.679481506347656}, "out_proj": {"bias": 0.6763718128204346, "kernel": 46.45091247558594}, "q_proj": {"bias": 1.7086448669433594, "kernel": 39.42691421508789}, "v_proj": {"bias": 0.4947702884674072, "kernel": 47.612117767333984}}, "feed_forward": {"intermediate_dense": {"bias": 1.9924734830856323, "kernel": 75.4644775390625}, "output_dense": {"bias": 0.6346121430397034, "kernel": 72.82086181640625}}, "final_layer_norm": {"bias": 1.1900091171264648, "scale": 23.649497985839844}, "layer_norm": {"bias": 1.2531957626342773, "scale": 20.66447639465332}}, "5": {"attention": {"k_proj": {"bias": 0.4255741834640503, "kernel": 50.193519592285156}, "out_proj": {"bias": 1.603766918182373, "kernel": 51.21357727050781}, "q_proj": {"bias": 2.7516400814056396, "kernel": 50.37419891357422}, "v_proj": {"bias": 0.3344327211380005, "kernel": 51.71560287475586}}, "feed_forward": {"intermediate_dense": {"bias": 1.888875961303711, "kernel": 104.60415649414062}, "output_dense": {"bias": 0.8974951505661011, "kernel": 94.77175903320312}}, "final_layer_norm": {"bias": 2.1963815689086914, "scale": 21.379741668701172}, "layer_norm": {"bias": 2.0436153411865234, "scale": 22.437408447265625}}, "6": {"attention": {"k_proj": {"bias": 0.4842904806137085, "kernel": 51.87656021118164}, "out_proj": {"bias": 1.5924865007400513, "kernel": 50.83064270019531}, "q_proj": {"bias": 2.7890353202819824, "kernel": 52.35090637207031}, "v_proj": {"bias": 0.324540376663208, "kernel": 51.10693359375}}, "feed_forward": {"intermediate_dense": {"bias": 1.8638681173324585, "kernel": 103.71176147460938}, "output_dense": {"bias": 0.7520318031311035, "kernel": 94.57533264160156}}, "final_layer_norm": {"bias": 2.5145840644836426, "scale": 20.836275100708008}, "layer_norm": {"bias": 2.0286154747009277, "scale": 23.157257080078125}}, "7": {"attention": {"k_proj": {"bias": 0.5047719478607178, "kernel": 51.46348190307617}, "out_proj": {"bias": 1.439831256866455, "kernel": 51.13825225830078}, "q_proj": {"bias": 2.5507259368896484, "kernel": 51.91937255859375}, "v_proj": {"bias": 0.4272071123123169, "kernel": 50.952415466308594}}, "feed_forward": {"intermediate_dense": {"bias": 1.874153733253479, "kernel": 103.49765014648438}, "output_dense": {"bias": 0.5875180959701538, "kernel": 94.39153289794922}}, "final_layer_norm": {"bias": 2.4167640209198, "scale": 21.010517120361328}, "layer_norm": {"bias": 1.9791065454483032, "scale": 22.196998596191406}}, "8": {"attention": {"k_proj": {"bias": 0.49701446294784546, "kernel": 51.12043762207031}, "out_proj": {"bias": 1.2548670768737793, "kernel": 51.65052795410156}, "q_proj": {"bias": 2.5421109199523926, "kernel": 51.03322982788086}, "v_proj": {"bias": 0.35404258966445923, "kernel": 51.66223907470703}}, "feed_forward": {"intermediate_dense": {"bias": 1.9299311637878418, "kernel": 103.21700286865234}, "output_dense": {"bias": 0.5481632947921753, "kernel": 93.97358703613281}}, 
"final_layer_norm": {"bias": 2.352778673171997, "scale": 20.739261627197266}, "layer_norm": {"bias": 1.9221278429031372, "scale": 22.404674530029297}}, "9": {"attention": {"k_proj": {"bias": 0.5226665735244751, "kernel": 52.009735107421875}, "out_proj": {"bias": 1.4968758821487427, "kernel": 52.6712646484375}, "q_proj": {"bias": 2.4627773761749268, "kernel": 52.267112731933594}, "v_proj": {"bias": 0.3845318853855133, "kernel": 52.86528396606445}}, "feed_forward": {"intermediate_dense": {"bias": 2.0269086360931396, "kernel": 101.98193359375}, "output_dense": {"bias": 0.6826795935630798, "kernel": 94.36759948730469}}, "final_layer_norm": {"bias": 2.325364828109741, "scale": 20.160507202148438}, "layer_norm": {"bias": 2.0237574577331543, "scale": 24.083885192871094}}}, "pos_conv_embed": {"conv": {"bias": 5.846972465515137, "weight_g": 9.124374389648438, "weight_v": 93.52505493164062}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 2.0290679931640625, "kernel": 20.55536460876465}, "layer_norm": {"bias": 4.550922393798828, "scale": 16.167570114135742}}, "1": {"conv": {"bias": 1.7790228128433228, "kernel": 51.24136734008789}, "layer_norm": {"bias": 5.962646961212158, "scale": 23.268157958984375}}, "2": {"conv": {"bias": 1.140576720237732, "kernel": 46.50312042236328}, "layer_norm": {"bias": 4.176670551300049, "scale": 20.370853424072266}}, "3": {"conv": {"bias": 0.6725863218307495, "kernel": 44.397525787353516}, "layer_norm": {"bias": 3.888174533843994, "scale": 17.53795051574707}}, "4": {"conv": {"bias": 0.6373162269592285, "kernel": 41.314056396484375}, "layer_norm": {"bias": 2.385471820831299, "scale": 16.34571647644043}}, "5": {"conv": {"bias": 0.5147221684455872, "kernel": 37.479759216308594}, "layer_norm": {"bias": 2.020900011062622, "scale": 17.064470291137695}}, "6": {"conv": {"bias": 0.4947893023490906, "kernel": 40.64780044555664}, "layer_norm": {"bias": 0.5876954793930054, "scale": 19.058603286743164}}}}, "feature_projection": {"layer_norm": {"bias": 6.3762969970703125, "scale": 16.443078994750977}, "projection": {"bias": 1.866883397102356, "kernel": 37.217613220214844}}, "masked_spec_embed": 11.914372444152832}}, "train/learning_rate": 4.088161585968919e-05, "train/loss": 0.11311191320419312, "train/param_norm": 1241.630615234375, "_runtime": 3024, "_timestamp": 1660147680, "_step": 275100, "_wandb": {"runtime": 3025}} \ No newline at end of file diff --git a/wandb/run-20220810_151736-2jo5la5b/logs/debug-internal.log b/wandb/run-20220810_151736-2jo5la5b/logs/debug-internal.log new file mode 100644 index 0000000000000000000000000000000000000000..a408227e12ff4b1e579f480b6a76799f245632c5 --- /dev/null +++ b/wandb/run-20220810_151736-2jo5la5b/logs/debug-internal.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:125838a3c175d34fe31998e2afe0d94313eee71844f52c2b76771510894a6441 +size 118448 diff --git a/wandb/run-20220810_151736-2jo5la5b/logs/debug.log b/wandb/run-20220810_151736-2jo5la5b/logs/debug.log new file mode 100644 index 0000000000000000000000000000000000000000..0b438c341cea83f62aad56827c0c438888058b24 --- /dev/null +++ b/wandb/run-20220810_151736-2jo5la5b/logs/debug.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:64280f619705177b5c06f8c125f22f5dd371ca22c3e452fbdcf399ec2896cb54 +size 6286 diff --git a/wandb/run-20220810_151736-2jo5la5b/run-2jo5la5b.wandb b/wandb/run-20220810_151736-2jo5la5b/run-2jo5la5b.wandb new file mode 100644 index 
0000000000000000000000000000000000000000..46eb7bba8ff8d4394d113dd55b3824a01e64ec9a --- /dev/null +++ b/wandb/run-20220810_151736-2jo5la5b/run-2jo5la5b.wandb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e23b59df66991dd3f2ccba945e8e71432f5136ba616d35486e15fca2c13f5864 +size 460658 diff --git a/wandb/run-20220811_082319-hrpkniwr/files/code/run_flax_speech_recognition_ctc.py b/wandb/run-20220811_082319-hrpkniwr/files/code/run_flax_speech_recognition_ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..688d0067bf5a29e7f75a50534392557c8a80a709 --- /dev/null +++ b/wandb/run-20220811_082319-hrpkniwr/files/code/run_flax_speech_recognition_ctc.py @@ -0,0 +1,1633 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2022 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition. +""" +# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. + +import logging +import math +import os +import re +import sys +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + +import datasets +import numpy as np +from datasets import DatasetDict, load_dataset, load_metric +from tqdm import tqdm + +import flax +import jax +import jax.numpy as jnp +import optax +import transformers +import wandb as wandb +from flax import core, jax_utils, struct, traverse_util +from flax.jax_utils import unreplicate, pad_shard_unpad +from flax.training.common_utils import get_metrics, shard, shard_prng_key +from huggingface_hub import Repository +from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC +from optax._src import linear_algebra +from transformers import ( + AutoFeatureExtractor, + AutoProcessor, + AutoTokenizer, + HfArgumentParser, + TrainingArguments, + is_tensorboard_available, + set_seed, +) +from transformers.file_utils import get_full_repo_name +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.17.0.dev0") + +require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") + +logger = logging.getLogger(__name__) + + +@flax.struct.dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
+ """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + feature_extractor_name: Optional[str] = field( + default=None, metadata={"help": "feature extractor name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + freeze_feature_encoder: bool = field( + default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."} + ) + attention_dropout: float = field( + default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."} + ) + activation_dropout: float = field( + default=0.1, + metadata={ + "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler." + }, + ) + hidden_dropout: float = field( + default=0.1, + metadata={ + "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler." + }, + ) + feat_proj_dropout: float = field( + default=0.0, + metadata={ + "help": "The feat proj dropout probability for feature encoder representations." + }, + ) + final_dropout: float = field( + default=0.0, + metadata={"help": "The dropout probability for the final projection layer."}, + ) + mask_time_prob: float = field( + default=0.1, + metadata={ + "help": "The spec aug dropout probability for feature encoder representations." + }, + ) + mask_time_length: int = field( + default=10, + metadata={"help": "Length of vector span to mask along the time axis."}, + ) + mask_feature_prob: float = field( + default=0.0, + metadata={ + "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector" + "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis." + }, + ) + mask_feature_length: int = field( + default=10, + metadata={"help": "Length of vector span to mask along the feature axis."}, + ) + layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."}) + ctc_loss_reduction: Optional[str] = field( + default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."} + ) + ctc_zero_infinity: Optional[bool] = field( + default=False, metadata={"help": "If True, will try yo aboud the CTC loss goinf to infinity."} + ) + + +@flax.struct.dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. 
+ """ + + dataset_name: str = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + text_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, + ) + dataset_cache_dir: Optional[str] = field( + default=None, metadata={"help": "Path to cache directory for saving and loading datasets"} + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_test_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of test examples to this " + "value if set." + }, + ) + audio_column_name: str = field( + default="audio", + metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"}, + ) + text_column_name: str = field( + default="text", + metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"}, + ) + max_duration_in_seconds: float = field( + default=20.0, + metadata={ + "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`" + }, + ) + min_duration_in_seconds: float = field( + default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"} + ) + max_label_length: Optional[int] = field( + default=512, + metadata={ + "help": "The minimum total sequence length for target text after tokenization. Sequences shorter " + "than this will be filtered." + }, + ) + min_label_length: Optional[int] = field( + default=2, + metadata={ + "help": "The minimum total sequence length for target text after tokenization. Sequences shorter " + "than this will be filtered." + }, + ) + pad_input_to_multiple_of: Optional[int] = field( + default=32000, + metadata={ + "help": "If set will pad the input sequence to a multiple of the provided value. " + "This is important to avoid triggering recompilations on TPU." + }, + ) + pad_target_to_multiple_of: Optional[int] = field( + default=None, + metadata={ + "help": "If set will pad the target sequence to a multiple of the provided value. " + "This is important to avoid triggering recompilations on TPU." + }, + ) + preprocessing_only: bool = field( + default=False, + metadata={ + "help": "Whether to only do data preprocessing and skip training. " + "This is especially useful when data preprocessing errors out in distributed training due to timeout. 
" + "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` " + "so that the cached datasets can consequently be loaded in distributed training" + }, + ) + train_split_name: str = field( + default="train", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + eval_split_name: str = field( + default="validation", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + do_lower_case: bool = field( + default=True, + metadata={"help": "Whether the target text should be lower cased."}, + ) + wandb_project: str = field( + default="flax-speech-recognition-ctc", + metadata={"help": "The name of the wandb project."}, + ) + wandb_name: str = field( + default=None, + metadata={"help": "The name of the wandb run."}, + ) + wandb_job_type: str = field( + default="CTC", + metadata={"help": "The name of the wandb job type."}, + ) + test_split_name: str = field( + default="test", + metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"}, + ) + remove_punctuation: bool = field( + default=False, metadata={"help": "Whether or not to remove punctuation during training."} + ) + skip_steps: Optional[int] = field( + default=0, + metadata={ + "help": "Skip this number of steps. Useful to continue training" + }, + ) + + +# @flax.struct.dataclass +@dataclass +class FlaxTrainingArguments(TrainingArguments): + precision: str = field( + default="full", + metadata={ + "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision" + "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**" + }, + ) + matmul_precision: str = field( + default="default", + metadata={ + "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. " + "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). " + "This configuration option does not change the behaviours of such calls with explicit precision arguments; " + "it only changes the behaviors of calls with no such argument provided. " + "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`." + }, + ) + multisteps: bool = field( + default=False, + metadata={ + "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, " + "a custom gradient accumulation implementation will be employed." + }, + ) + + +def to_fp32(t): + return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t) + + +def to_bf16(t): + return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t) + + +class MixedPrecisionTrainState(struct.PyTreeNode): + """Train state for use with a single Optax optimizer. 
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py + + Synopsis:: + + state = TrainState.create( + apply_fn=model.apply, + params=variables['params'], + tx=tx) + grad_fn = jax.grad(make_loss_fn(state.apply_fn)) + for batch in data: + grads = grad_fn(state.params, batch) + state = state.apply_gradients(grads=grads) + + Args: + step: Counter starts at 0 and is incremented by every call to + `.apply_gradients()`. + apply_fn: Usually set to `model.apply()`. Kept in this dataclass for + convenience to have a shorter params list for the `train_step()` function + in your training loop. + params: The parameters to be updated by `tx` and used by `apply_fn`. + tx: An Optax gradient transformation. + opt_state: The state for `tx`. + dropout_rng: PRNG key for stochastic operations. + bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. + """ + + step: int + apply_fn: Callable = struct.field(pytree_node=False) + get_attention_mask_fn: Callable = struct.field(pytree_node=False) + params: core.FrozenDict[str, Any] + tx: optax.GradientTransformation = struct.field(pytree_node=False) + opt_state: optax.OptState + dropout_rng: jnp.ndarray + max_grad_norm: Optional[float] = 1.0 + + def apply_gradients(self, *, grads, to_dtype, **kwargs): + """Updates `step`, `params`, `opt_state` and `**kwargs` in return value. + + Note that internally this function calls `.tx.update()` followed by a call + to `optax.apply_updates()` to update `params` and `opt_state`. + + Args: + grads: Gradients that have the same pytree structure as `.params`. + **kwargs: Additional dataclass attributes that should be `.replace()`-ed. + + Returns: + An updated instance of `self` with `step` incremented by one, `params` + and `opt_state` updated by applying `grads`, and additional attributes + replaced as specified by `kwargs`. + """ + + # clip gradients by global l2 norm + casted_max_grad_norm = to_dtype(self.max_grad_norm) + g_norm = linear_algebra.global_norm(grads) + g_norm = jnp.maximum(casted_max_grad_norm, g_norm) + grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads) + + # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training + # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is) + updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params) + + new_params = optax.apply_updates(self.params, updates) + return self.replace( + step=self.step + 1, + params=new_params, + opt_state=to_dtype(new_opt_state), + **kwargs, + ) + + @classmethod + def create(cls, *, apply_fn, params, tx, to_dtype, step=0, **kwargs): + """Creates a new instance with `step=0` and initialized `opt_state`.""" + # downcast optimizer state to bf16 if mixed-precision training + opt_state = tx.init(to_dtype(params)) if tx is not None else None + return cls( + step=step, + apply_fn=apply_fn, + params=params, + tx=tx, + opt_state=opt_state, + **kwargs, + ) + + def replicate(self): + return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) + + +@flax.struct.dataclass +class FlaxDataCollatorSpeechSeq2SeqWithPadding: + """ + Data collator that will dynamically pad the inputs received. + Args: + processor ([`Wav2Vec2Processor`]) + The processor used for proccessing the data. + decoder_start_token_id (:obj: `int`) + The begin-of-sentence of the decoder. 
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned input sequences (according to the model's padding side and padding index) + among: + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned target sequences (according to the model's padding side and padding index). + See above for details. + max_input_length (:obj:`float`, `optional`): + Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). + pad_input_to_multiple_of (:obj:`int`, `optional`): + If set will pad the input sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + pad_target_to_multiple_of (:obj:`int`, `optional`): + If set will pad the target sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + """ + + processor: Any + input_padding: Union[bool, str] = "longest" + label_padding: Union[bool, str] = "max_length" + pad_input_to_multiple_of: Optional[int] = None + pad_to_multiple_of_label: Optional[int] = None + max_input_length: Optional[float] = None + max_label_length: Optional[float] = None + + def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]: + # split inputs and labels since they have to be of different lengths and need + # different padding methods + input_features = [{"input_values": feature["input_values"]} for feature in features] + label_features = [{"input_ids": feature["labels"]} for feature in features] + + # reformat list to dict and set to pytorch format + batch = self.processor.feature_extractor.pad( + input_features, + max_length=self.max_input_length, + padding=self.input_padding, + pad_to_multiple_of=self.pad_input_to_multiple_of, + return_tensors="np", + ) + + labels_batch = self.processor.tokenizer.pad( + label_features, + max_length=self.max_label_length, + padding=self.label_padding, + pad_to_multiple_of=self.pad_to_multiple_of_label, + return_tensors="np", + ) + + labels = labels_batch["input_ids"] + labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1)) + labels = labels.filled(fill_value=-100) + + batch["labels"] = labels + + return batch + + +def get_grouped_indices( + dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None +) -> np.array: + """ + Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486) + Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar + lengths. 
To do this, the indices are: + + - randomly permuted (if a JAX rng is specified) + - grouped in mega-batches of size `mega_batch_mult * batch_size` + - sorted by length in each mega-batch + + The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of + maximum length placed first, so that an OOM happens sooner rather than later. + """ + lengths = dataset["input_length"] + + # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller. + if mega_batch_mult is None: + mega_batch_mult = min(len(lengths) // (batch_size * 4), 50) + # Just in case, for tiny datasets + if mega_batch_mult == 0: + mega_batch_mult = 1 + + # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler. + num_samples = len(lengths) + indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples) + + megabatch_size = mega_batch_mult * batch_size + megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] + megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches] + + # The rest is to get the biggest batch first. + # Since each megabatch is sorted by descending length, the longest element is the first + megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches] + max_idx = np.argmax(megabatch_maximums).item() + # Switch to put the longest batch in first position + # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch) + megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0] + + megabatches = np.array([i for megabatch in megabatches for i in megabatch]) + + return megabatches + + +def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray: + """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by + the batch size and `drop_last` is `True`, the last incomplete batch is dropped. 
Else, it is returned.""" + num_samples = len(samples_idx) + if drop_last: + samples_to_remove = num_samples % batch_size + if samples_to_remove != 0: + samples_idx = samples_idx[:-samples_to_remove] + sections_split = num_samples // batch_size + samples_idx = samples_idx.reshape((sections_split, batch_size)) + else: + sections_split = math.ceil(num_samples / batch_size) + samples_idx = np.array_split(samples_idx, sections_split) + return samples_idx + + +def write_train_metric(summary_writer, train_metrics, train_time, step): + summary_writer.scalar("train_time", train_time, step) + + train_metrics = get_metrics(train_metrics) + for key, vals in train_metrics.items(): + tag = f"train_{key}" + for i, val in enumerate(vals): + summary_writer.scalar(tag, val, step - len(vals) + i + 1) + + +def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None): + for metric_name, value in eval_metrics.items(): + summary_writer.scalar(f"eval_{metric_name}", value, step) + + if pred_str is not None: + # write output actual predictions for debugging + summary_writer.text("eval_predictions", "\n".join(pred_str), step) + + +def write_wandb_log(metrics, step, prefix=None): + if jax.process_index() == 0: + log_metrics = {} + for k, v in metrics.items(): + if "layer" in k: + log_metrics[f"{k}/"] = v + elif prefix is not None: + log_metrics[f"{prefix}/{k}"] = v + else: + log_metrics[k] = v + wandb.log(log_metrics, step) + + +def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"): + if jax.process_index() == 0: + # convert str data to a wandb compatible format + str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))] + # we'll log the first 50 predictions for each epoch + wandb.log( + { + f"{prefix}/step_{int(step / 1000)}k": wandb.Table( + columns=["label_str", "pred_str"], data=str_data[:num_log] + ) + }, + step, + ) + + +def create_learning_rate_fn( + num_train_steps: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.array]: + """Returns a linear warmup, linear_decay learning rate function.""" + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +def ctc_loss( + logits, + logits_attention_mask, + labels, + blank_id, + loss_reduction="mean", + output_emission_dict=False, + log_epsilon=-100000.0, +): + """Computes CTC loss. + This function performs forward computation over an FSA with `N * 2` states + where `N` is the max number of labels. The states are split into two groups: + Phi states and emission states. a phi-state accepts repetition of + phi (blank)-symbols and transits to emission state when the correct label is + observed. An emission state accepts repetition of the label and transits to + the next phi states at any time (so called epsilon-transition). + Below, `B` denotes the batch size, `T` denotes the time steps in `logits`, + and `N` denotes the time steps in `labels`. + Args: + logits: (B, T, K)-array containing log-probabilities of each class. + logitpaddings: (B, T)-array. Padding indicators for `logits`. + labels: (B, N)-array containing reference integer labels. + labelpaddings: (B, N)-array. Padding indicators for `labels`. Currently, + `labels` must be right-padded, i.e. 
each row of `labelpaddings` must be + repetition of zeroes, followed by repetition of ones. + blank_id: Id for blank token. + loss_reduction: one of "mean", "sum", "default" + - "none": no reduction is applied. + - "mean": output loss will be divided by target lengths and then the + mean over the batch is taken. + - "sum": output loss are summed over batch + output_emission_dict: whether to output additional information about the emission probs + Returns: + A pair of `(per_seq_loss, aux)`. + per_seq_loss: + (B,)-array containing loss values for each sequence in the batch. + aux: Dictionary containing interim variables used for computing losses. + aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each + phi-state corresponding to the n-th label. + aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each + emission-state corresponding to the n-th label. + aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol + corresponding to each time frame. + aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label + corresponding to each time frame. + """ + # label paddings are indicated by -100 + labelpaddings = labels < 0 + # logit paddings are the inverse of attention_mask + logitpaddings = ~logits_attention_mask + + # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py + batchsize, unused_maxinputlen, num_classes = logits.shape + batchsize_, maxlabellen = labels.shape + + logprobs = jax.nn.log_softmax(logits) + labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32) + + # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1]. + repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32) + repeat = jnp.pad(repeat, ((0, 0), (0, 1))) + + logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1] + logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] + + one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K] + logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot) + logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] + + logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N] + logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0) + logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N] + + def loop_body(prev, x): + prev_phi, prev_emit = prev + # emit-to-phi epsilon transition, except if the next label is repetition + prev_phi_orig = prev_phi + prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat)) + + logprob_emit, logprob_phi, pad = x + + # phi-to-emit transition + next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit) + # self-loop transition + next_phi = prev_phi + logprob_phi + # emit-to-phi blank transition only when the next label is repetition + next_phi = next_phi.at[:, 1:].set( + jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)) + ) + + pad = pad.reshape((batchsize, 1)) + next_emit = pad * prev_emit + (1.0 - pad) * next_emit + next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi + + return (next_phi, next_emit), (next_phi, next_emit) + + xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0))) + _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs) + + # last row needs to be updated with the last epsilon transition + logalpha_phi_last = logalpha_phi[-1].at[:, 
1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1])) + logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last) + + # extract per_seq_loss + one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1] + per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot) + + if loss_reduction == "mean": + target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1) + loss = (per_seq_loss / target_lengths).mean() + elif loss_reduction == "sum": + loss = per_seq_loss.sum() + else: + loss = per_seq_loss + + if not output_emission_dict: + return loss + + return loss, { + "logalpha_phi": logalpha_phi, + "logalpha_emit": logalpha_emit, + "logprobs_phi": logprobs_phi, + "logprobs_emit": logprobs_emit, + } + + +def make_dataset(data_args, seed=42): + # Pre-processing dataset + import re + + def map_nst(entry): + text = entry["text"].lower() + text = text.replace("(...vær stille under dette opptaket...)", "") + text = re.sub('[áàâ]', 'a', text) + text = re.sub('[ä]', 'æ', text) + text = re.sub('[éèëê]', 'e', text) + text = re.sub('[íìïî]', 'i', text) + text = re.sub('[óòöô]', 'o', text) + text = re.sub('[ö]', 'ø', text) + text = re.sub('[ç]', 'c', text) + text = re.sub('[úùüû]', 'u', text) + # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text) + text = re.sub('\s+', ' ', text) + return {"text": text} + + def filter_nst(entry): + if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)): + return False # Too short + if re.match(entry["type"], "pIW|CA"): + return False # Spelling out words + return True + + def filter_npsc(entry): + # False if there are digits in the text + if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)): + return False # Too short + if re.search("\d", entry["text"]): + return False + return True + + def map_npsc(entry): + batch = {"text": entry["text"].lower()} + batch["text"] = re.sub('[áàâ]', 'a', batch["text"]) + batch["text"] = re.sub('[ä]', 'æ', batch["text"]) + batch["text"] = re.sub('[éèëê]', 'e', batch["text"]) + batch["text"] = re.sub('[íìïî]', 'i', batch["text"]) + batch["text"] = re.sub('[óòöô]', 'o', batch["text"]) + batch["text"] = re.sub('[ö]', 'ø', batch["text"]) + batch["text"] = re.sub('[ç]', 'c', batch["text"]) + batch["text"] = re.sub('[úùüû]', 'u', batch["text"]) + batch["text"] = re.sub('\s', ' ', batch["text"]) + batch["text"] = re.sub('', 'eee', batch["text"]) + batch["text"] = re.sub('', 'qqq', batch["text"]) + batch["text"] = re.sub('', 'mmm', batch["text"]) + batch["text"] = re.sub('', 'xxx', batch["text"]) + # batch["text"] = re.sub('', '?', batch["text"]) + if "<" in batch["text"]: + raise ValueError(batch["text"]) + return batch + + nst = datasets.load_dataset("NbAiLab/NST", "no-close") + npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3") + # TODO NST_hesitate + + split = len(npsc[data_args.train_split_name]) / (len(npsc[data_args.train_split_name]) + len(npsc[data_args.eval_split_name])) # Use same train/val ratio as NPSC + nst_train = nst[data_args.train_split_name].train_test_split(train_size=split, seed=seed) + nst[data_args.train_split_name] = nst_train["train"] + nst[data_args.eval_split_name] = nst_train["test"] + + nst = nst.filter(filter_nst).map( + map_nst, + num_proc=data_args.preprocessing_num_workers, + desc="filtering NST", + ).shuffle(seed=seed) + npsc = npsc.filter(filter_npsc).map( + map_npsc, + num_proc=data_args.preprocessing_num_workers, + desc="filtering NPSC", 
+ ).shuffle(seed=seed) + + npsc_base = npsc.remove_columns([col for col in npsc[data_args.train_split_name].column_names if col not in ["text", "audio"]]) + nst_base = nst.remove_columns([col for col in nst[data_args.train_split_name].column_names if col not in ["text", "audio"]]) + + combined = {} + for split in data_args.train_split_name, data_args.eval_split_name, data_args.test_split_name: + probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples + probs = (probs / probs.sum()).tolist() + comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed) + combined[split] = comb + + return datasets.DatasetDict(**combined) + +def main(): + # 1. Parse input arguments + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # 2. Setup logging + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + # Set the verbosity to info of the Transformers logger. + # We only want one process per machine to log things on the screen. + logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) + if jax.process_index() == 0: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # Set up wandb run + if jax.process_index() == 0: + wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type) + + logger.info("Training/evaluation parameters %s", training_args) + + # Set the default TPU matmul precision and display the number of devices + jax.config.update("jax_default_matmul_precision", training_args.matmul_precision) + logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}") + + # 4. 
Load dataset + + set_seed(training_args.seed) + raw_datasets = make_dataset(data_args, seed=training_args.seed) + + # raw_datasets = DatasetDict() + + # if training_args.do_train: + # raw_datasets[data_args.train_split_name] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=data_args.train_split_name, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + # if training_args.do_eval: + # raw_datasets[data_args.eval_split_name] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=data_args.eval_split_name, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + # if training_args.do_predict: + # test_split = data_args.test_split_name.split("+") + # for split in test_split: + # raw_datasets[split] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=split, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + if not training_args.do_train and not training_args.do_eval and not training_args.do_predict: + raise ValueError( + "Cannot not train, not do evaluation and not do prediction. At least one of " + "training, evaluation or prediction has to be done." + ) + + # if not training, there is no need to run multiple epochs + if not training_args.do_train: + training_args.num_train_epochs = 1 + + if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--audio_column_name` to the correct audio column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--text_column_name` to the correct text column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + # 5. 
Load pretrained model, tokenizer, and feature extractor + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + config = Wav2Vec2Config.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + # update config according to training args, model args, and tokenizer attributes + config.update( + { + "feat_proj_dropout": model_args.feat_proj_dropout, + "attention_dropout": model_args.attention_dropout, + "hidden_dropout": model_args.hidden_dropout, + "final_dropout": model_args.final_dropout, + "mask_time_prob": model_args.mask_time_prob, + "mask_time_length": model_args.mask_time_length, + "mask_feature_prob": model_args.mask_feature_prob, + "mask_feature_length": model_args.mask_feature_length, + "gradient_checkpointing": training_args.gradient_checkpointing, + "layerdrop": model_args.layerdrop, + "ctc_loss_reduction": model_args.ctc_loss_reduction, + "ctc_zero_infinity": model_args.ctc_zero_infinity, + "pad_token_id": tokenizer.pad_token_id, + "vocab_size": tokenizer.vocab_size, # len(tokenizer), + "activation_dropout": model_args.activation_dropout, + } + ) + + if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr": + raise ValueError( + "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to " + "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus," + "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is mostly likely " + "built on an lowercased corpus. In this case, set `tokenizer.do_lower_case` to ``False`." + ) + + if training_args.precision == "full_mixed": + dtype = jnp.bfloat16 + training_args.mixed_precision = True + elif training_args.precision == "half_mixed": + dtype = jnp.bfloat16 + training_args.mixed_precision = False + else: + dtype = jnp.float32 + training_args.mixed_precision = False + + try: + model = FlaxWav2Vec2ForCTC.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + except: + model = FlaxWav2Vec2ForCTC.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + from_pt=True, + ) + + # 6. Resample speech dataset ALWAYS + raw_datasets = raw_datasets.cast_column( + data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) + ) + + # 7. Preprocessing the datasets. 
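+ # (Durations below are specified in seconds and converted to sample counts via the feature extractor's
+ # sampling rate, e.g. 30.0 s at 16 kHz gives 480 000 input values; label lengths are counted in tokens.)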
+ # We need to read the audio files as arrays and tokenize the targets. + max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate) + min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate) + max_target_length = data_args.max_label_length + min_target_length = data_args.min_label_length + pad_input_to_multiple_of = data_args.pad_input_to_multiple_of + audio_column_name = data_args.audio_column_name + num_workers = data_args.preprocessing_num_workers + text_column_name = data_args.text_column_name + model_input_name = feature_extractor.model_input_names[0] + do_lower_case = data_args.do_lower_case + dataset_name = data_args.dataset_name + chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ") + chars_to_ignore_regex = f'[{"".join(chars_to_ignore)}]' + # gigaspeech_punctuation = {" ": ",", " ": ".", " ": "?", " ": "!"} + # gigaspeech_disfluencies = ["", ""] + # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "", "", "", "[laughter-", + # "[vocalized-noise]", "_1"] + # swb_punctuations = ["{", "}", "[", "]-", "]"] + # earnings_disfluencies = ["", "", "", "inaudible", "", ""] + ignore_segments = ["ignore_time_segment_in_scoring", "", "", "[noise]", "[laughter]", "[silence]", + "[vocalized-noise]", "", "", "", "", "", "", ""] + + if training_args.do_train and data_args.max_train_samples is not None: + raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].select(range(data_args.max_train_samples)) + + if training_args.do_eval and data_args.max_eval_samples is not None: + raw_datasets[data_args.eval_split_name] = raw_datasets[data_args.eval_split_name].select(range(data_args.max_eval_samples)) + + if training_args.do_predict and data_args.max_test_samples is not None: + raw_datasets[data_args.test_split_name] = raw_datasets[data_args.test_split_name].select(range(data_args.max_eval_samples)) + + if training_args.do_train and data_args.remove_punctuation: + + def remove_punctuation(batch): + batch[text_column_name] = ( + re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "") + ) + + raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].map( + remove_punctuation, + num_proc=data_args.preprocessing_num_workers, + desc="removing punctuation from train split", + ) + + # filter data where the targets are ignored in scoring + def is_target_labels(input_str): + return input_str.lower() not in ignore_segments + + raw_datasets = raw_datasets.filter( + is_target_labels, + num_proc=num_workers, + input_columns=[text_column_name], + desc="filtering data where the targets are ignored in scoring", + ) + + def prepare_dataset(batch): + # process audio + try: + sample = batch[audio_column_name] + except ValueError: + sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate} + inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"]) + # process audio length + batch[model_input_name] = inputs.input_values[0] + batch["input_length"] = len(batch["input_values"]) + + # process targets + input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name] + + # if dataset_name == "google/xtreme_s": + # # Finally, we tokenize the processed text + # batch["labels"] = tokenizer(input_str).input_ids + # batch["labels_length"] = len(batch["labels"]) + # return batch + + # # Common Voice 9 + # if input_str.startswith('"') and input_str.endswith('"'): + # # we can remove 
trailing quotation marks as they do not affect the transcription + # input_str = input_str[1:-1] + # # normalize quotation marks + # input_str = re.sub(r'["“”]', '"', input_str) + # # normalize apostrophes + # input_str = re.sub(r"[’']", "'", input_str) + # # normalize hyphens + # input_str = re.sub(r"[—–]", "-", input_str) + # # replace double quotation marks with single + # input_str = input_str.replace('""', '"') + # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str): + # # for CV9, we'll normalize the text to always finish with punctuation + # if input_str[-1] not in [".", "?", "!"]: + # input_str = input_str + "." + + # # TEDLIUM-3 + # # delete the token from the text and replace spaced apostrophes with un-spaced + # input_str = input_str.replace("", "").replace(" '", "'") + + # # GigaSpeech + # for disfluency in gigaspeech_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # convert spelled out punctuation to symbolic form + # for punctuation, replacement in gigaspeech_punctuation.items(): + # input_str = input_str.replace(punctuation, replacement) + # if dataset_name == "speechcolab/gigaspeech" and len(input_str): + # # for GS, we'll normalize the text to always finish with punctuation + # if input_str[-1] not in [".", "?", "!"]: + # input_str = input_str + "." + + # # SWB + # for disfluency in swb_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # remove parenthesised text (test data only) + # input_str = re.sub("[\(].*?[\)]", "", input_str) + # for punctuation in swb_punctuations: + # input_str = input_str.replace(punctuation, "") + # # replace anomalous words with their correct transcriptions + # split_str = input_str.split("/") + # if len(split_str) > 1: + # input_str = " ".join( + # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]]) + + # # Earnings 22 + # for disfluency in earnings_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # replace mal-formatted ellipsis + # input_str = input_str.replace("…", ".") + + # JIWER compliance + # remove multiple spaces + input_str = re.sub(r"\s\s+", " ", input_str) + # strip trailing spaces + input_str = input_str.strip() + + # Finally, we tokenize the processed text + batch["labels"] = tokenizer(input_str).input_ids + batch["labels_length"] = len(batch["labels"]) + return batch + + vectorized_datasets = raw_datasets.map( + prepare_dataset, + remove_columns=next(iter(raw_datasets.values())).column_names, + num_proc=num_workers, + desc="preprocess dataset", + ) + + # filter data with inputs shorter than min_input_length or longer than max_input_length + def is_audio_in_length_range(length): + return length > min_input_length and length < max_input_length + + vectorized_datasets = vectorized_datasets.filter( + is_audio_in_length_range, + num_proc=num_workers, + input_columns=["input_length"], + ) + + # filter data with targets shorter than min_target_length or longer than max_target_length + def is_labels_in_length_range(length): + return length > min_target_length # and length < max_target_length + + vectorized_datasets = vectorized_datasets.filter( + is_labels_in_length_range, + num_proc=num_workers, + input_columns=["labels_length"], + ) + + # for large datasets it is advised to run the preprocessing on a + # single machine first with `args.preprocessing_only` since there will mostly likely + # be a timeout when running the script in distributed mode. 
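+ # (for example, launch once with `--preprocessing_only` on a single host, wait for the cache files to be
+ # written, and then restart the full training command)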
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the + # cached dataset + if data_args.preprocessing_only: + cache = {k: v.cache_files for k, v in vectorized_datasets.items()} + logger.info(f"Data preprocessing finished. Files cached at {cache}.") + return + + # 8. Load Metrics + wer_metric = load_metric("wer") + cer_metric = load_metric("cer") + + def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]): + padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids)) + + pred_str = tokenizer.batch_decode(pred_ids) + # we do not want to group tokens when computing the metrics + label_str = tokenizer.batch_decode(padded_ids, group_tokens=False) + + wer = wer_metric.compute(predictions=pred_str, references=label_str) + cer = cer_metric.compute(predictions=pred_str, references=label_str) + + return {"wer": wer, "cer": cer}, pred_str, label_str + + # 9. save feature extractor, tokenizer and config + feature_extractor.save_pretrained(training_args.output_dir) + tokenizer.save_pretrained(training_args.output_dir) + config.save_pretrained(training_args.output_dir) + + processor = AutoProcessor.from_pretrained(training_args.output_dir) + + data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding( + processor=processor, + input_padding="longest", + pad_input_to_multiple_of=pad_input_to_multiple_of, + max_label_length=data_args.max_label_length, + ) + + # Enable tensorboard only on the master node + has_tensorboard = is_tensorboard_available() + if has_tensorboard and jax.process_index() == 0: + try: + from flax.metrics.tensorboard import SummaryWriter + + summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) + except ImportError as ie: + has_tensorboard = False + logger.warning( + f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" + ) + else: + logger.warning( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run `pip install tensorboard` to enable." + ) + + # 10. Handle the repository creation + if training_args.push_to_hub: + with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f: + git_lfs_extensions = f.read() + if "*.wandb" not in git_lfs_extensions: + f.write("*.wandb filter=lfs diff=lfs merge=lfs -text") + if training_args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) + else: + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) + + # 11. 
Initialize our training + rng = jax.random.PRNGKey(training_args.seed) + rng, dropout_rng = jax.random.split(rng) + + # Store some constants + max_steps = int(training_args.max_steps) + gradient_accumulation_steps = int(training_args.gradient_accumulation_steps) + train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() + batch_size_per_update = train_batch_size * gradient_accumulation_steps + per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) + eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() + to_dtype = to_bf16 if training_args.mixed_precision else to_fp32 + + if training_args.do_train: + num_train_samples = len(vectorized_datasets[data_args.train_split_name]) + steps_per_epoch = num_train_samples // batch_size_per_update + if max_steps > 0: + num_epochs = -(training_args.max_steps // -steps_per_epoch) + total_train_steps = max_steps + else: + num_epochs = int(training_args.num_train_epochs) + total_train_steps = steps_per_epoch * num_epochs + + # Create learning rate schedule + # Create learning rate schedule + linear_decay_lr_schedule_fn = create_learning_rate_fn( + total_train_steps, + training_args.warmup_steps, + training_args.learning_rate, + ) + + # We use Optax's "masking" functionality to not apply weight decay + # to bias and LayerNorm scale parameters. decay_mask_fn returns a + # mask boolean with the same structure as the parameters. + # The mask is True for parameters that should be decayed. + # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart. + # For FlaxT5, one should correct the layer norm parameter naming + # accordingly - see `run_t5_mlm_flax.py` e.g. + def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + layer_norm_params = [ + (name, "scale") + for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"] + ] + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + + if training_args.adafactor: + # Create Adafactor optimizer + optim = optax.adafactor( + learning_rate=linear_decay_lr_schedule_fn, + dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32, + weight_decay_rate=training_args.weight_decay, + weight_decay_mask=decay_mask_fn, + ) + else: + # Create AdamW optimizer + optim = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn, + ) + + # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. 
gradient accumulation steps > 1) + if training_args.multisteps and gradient_accumulation_steps > 1: + optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False) + else: + num_epochs = 0 + total_train_steps = 0 + num_train_samples = 0 + optim = None + + # Setup train state + state = MixedPrecisionTrainState.create( + step=data_args.skip_steps, + apply_fn=model.__call__, + get_attention_mask_fn=model._get_feature_vector_attention_mask, + params=model.params, + tx=optim, + to_dtype=to_dtype, + dropout_rng=dropout_rng, + max_grad_norm=training_args.max_grad_norm, + ) + + # Replicate the train state on each device + state = state.replicate() + blank_id = model.config.pad_token_id + + # Define gradient update step fn + def train_step(state, batch): + # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch + dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + + def compute_loss(params, minibatch): + labels = minibatch.pop("labels") + logits = state.apply_fn( + **minibatch, + params=params, + dropout_rng=dropout_rng, + freeze_feature_encoder=model_args.freeze_feature_encoder, + train=True, + )[0] + logits_mask = state.get_attention_mask_fn(logits.shape[1], batch["attention_mask"]) + loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean") + + return loss + + grad_fn = jax.value_and_grad(compute_loss) + + if gradient_accumulation_steps == 1 or training_args.multisteps: + loss, grad = grad_fn(to_dtype(state.params), batch) + + # Custom gradient accumulation + else: + # add a first dimension over gradient_accumulation_steps for minibatch slices + batch = jax.tree_util.tree_map( + lambda x: x.reshape( + gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::] + ), + batch, + ) + + def accum_minibatch_step(accum_grad, minibatch): + # compute loss, num labels and grad over minibatch and accumulate + loss, grad = grad_fn(to_dtype(state.params), minibatch) + return jax.tree_util.tree_map(jnp.add, accum_grad, grad), loss + + # create an initial state for accumulating losses, num labels and gradients + init_grad = jax.tree_util.tree_map(jnp.zeros_like, to_dtype(state.params)) + # loop accum minibatch step over the number of gradient accumulation steps + grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch) + + # update state + new_state = state.apply_gradients( + grads=grad, + dropout_rng=new_dropout_rng, + to_dtype=to_dtype, + ) + + # compute gradient norms over all layers and globally for detailed monitoring + layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad) + logs = { + "layer_grad_norm": layer_grad_norm, + "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)), + } + + # compute parameter norms over all layers and globally for detailed monitoring + layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params) + logs["layer_param_norm"] = layer_param_norm + logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm)) + + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + metrics.update(logs) + + metrics = jax.lax.pmean(metrics, axis_name="batch") + # metrics = to_fp32(metrics) + + return new_state, metrics + + # Define eval fn + def eval_step(params, batch): + labels = batch.pop("labels") + logits = model(**batch, params=params, train=False)[0] + + logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], 
batch["attention_mask"]) + loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean") + + pred_ids = jnp.argmax(logits, axis=-1) + + # summarize metrics + metrics = {"loss": loss} + metrics = jax.lax.pmean(metrics, axis_name="batch") + # metrics = to_fp32(metrics) + return metrics, pred_ids + + # Create parallel version of the train and eval step + if training_args.do_train: + p_train_step = jax.jit(jax.pmap(train_step, "batch", donate_argnums=(0,))) + + if training_args.do_eval: + p_eval_step = jax.jit(jax.pmap(eval_step, "batch")) + + def run_evaluation(step): + if training_args.do_eval: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_preds = [] + eval_labels = [] + + # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length + eval_samples_idx = get_grouped_indices(vectorized_datasets[data_args.eval_split_name], eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) + + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): + samples = [vectorized_datasets[data_args.eval_split_name][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + labels = batch["labels"] + + metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size) + eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))) + eval_metrics.append(metrics) + + eval_labels.extend(labels) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + eval_metrics = to_fp32(eval_metrics) + + # always run compute metrics + error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels) + eval_metrics.update(error_rate_metric) + error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()]) + + # Print metrics and update progress bar + desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})" + epochs.write(desc) + epochs.desc = desc + + # Save metrics + write_wandb_log(eval_metrics, step, prefix="eval") + write_wandb_pred(pred_str, label_str, step) + # if has_tensorboard and jax.process_index() == 0: + # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str) + + def save_checkpoint(step): + # save and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False) + + skip_epochs = data_args.skip_steps // (num_train_samples // batch_size_per_update) + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_train_samples}") + logger.info(f" Num Epochs = {num_epochs}") + logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") + logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}") + logger.info(f" Total train batch size (w. 
parallel & distributed) = {batch_size_per_update}") + logger.info(f" Total optimization steps = {total_train_steps}") + logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}") + logger.info(f" Use scan: {config.use_scan}") + logger.info(f" Fuse matmuls: {config.fuse_matmuls}") + logger.info(f" Skipping: {data_args.skip_steps} steps ({skip_epochs} epochs)") + + train_time = cur_step = 0 + epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) + for epoch in epochs: + if training_args.do_train: + # ======================== Training ================================ + train_start = time.time() + + if epoch < skip_epochs: + logger.info(f"Skipping epoch {epoch + 1}") + # Create sampling rng + rng, input_rng = jax.random.split(rng) + continue + + # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length + train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng) + train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update) + + if data_args.skip_steps > cur_step: + logger.info(f"Skipping {data_args.skip_steps - (epoch * (num_train_samples // batch_size_per_update))} steps...") + # Gather the indices for creating the batch and do a training step + for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1): + cur_step = epoch * (num_train_samples // batch_size_per_update) + step + if cur_step <= data_args.skip_steps: + continue + + samples = [vectorized_datasets[data_args.train_split_name][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + batch = shard(batch.data) + try: + state, train_metric = p_train_step(state, batch) + except TypeError as e: + logger.warning("Encountered following error: \n", e) + + + if cur_step % training_args.logging_steps == 0: + # Save metrics + train_metric = unreplicate(train_metric) + train_time += time.time() - train_start + # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step + write_wandb_log(to_fp32(train_metric), cur_step, prefix=data_args.train_split_name) + # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis) + # if has_tensorboard and jax.process_index() == 0: + # write_train_metric(summary_writer, train_metrics, train_time, cur_step) + + epochs.write( + f"Step... 
({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})" + ) + + if cur_step % total_train_steps == 0: + break + + if training_args.eval_steps and cur_step % training_args.eval_steps == 0: + run_evaluation(cur_step) + p_train_step.clear_cache() + p_eval_step.clear_cache() + + if cur_step % training_args.save_steps == 0: + save_checkpoint(cur_step) + + if training_args.eval_steps == 0 and (epoch + 1) != num_epochs: + # run evaluation at the end of the epoch if eval steps are not specified + run_evaluation(cur_step) + save_checkpoint(cur_step) + + if training_args.do_train: + save_checkpoint(cur_step) + + cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training + + if training_args.do_eval: + run_evaluation(cur_step) + + # TODO: collapse 'do_predict' into the run_evaluation function + if training_args.do_predict: + for split in [data_args.test_split_name]: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_preds = [] + eval_labels = [] + + # Generate eval set by sequentially sampling indices from the test dataset and grouping by length + eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) + + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)): + samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + labels = batch["labels"] + + metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size) + eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))) + eval_metrics.append(metrics) + + eval_labels.extend(labels) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + eval_metrics = to_fp32(eval_metrics) + + # always run compute metrics + error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels) + eval_metrics.update(error_rate_metric) + error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()]) + + # Print metrics and update progress bar + desc = f"Step... 
({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})" + epochs.write(desc) + epochs.desc = desc + + # Save metrics + write_wandb_log(eval_metrics, cur_step, prefix=split) + write_wandb_pred(pred_str, label_str, cur_step, prefix=split) + # if has_tensorboard and jax.process_index() == 0: + # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str) + + +if __name__ == "__main__": + main() diff --git a/wandb/run-20220811_082319-hrpkniwr/files/config.yaml b/wandb/run-20220811_082319-hrpkniwr/files/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7e2539d903cff2e0f0493ae0d6085c4456837209 --- /dev/null +++ b/wandb/run-20220811_082319-hrpkniwr/files/config.yaml @@ -0,0 +1,33 @@ +wandb_version: 1 + +_wandb: + desc: null + value: + cli_version: 0.12.9 + code_path: code/run_flax_speech_recognition_ctc.py + framework: huggingface + huggingface_version: 4.21.0 + is_jupyter_run: false + is_kaggle_kernel: false + python_version: 3.8.10 + start_time: 1660206199 + t: + 1: + - 1 + - 2 + - 3 + - 11 + - 12 + 2: + - 1 + - 2 + - 3 + - 11 + - 12 + 3: + - 13 + 4: 3.8.10 + 5: 0.12.9 + 6: 4.21.0 + 8: + - 5 diff --git a/wandb/run-20220811_082319-hrpkniwr/files/diff.patch b/wandb/run-20220811_082319-hrpkniwr/files/diff.patch new file mode 100644 index 0000000000000000000000000000000000000000..3ba35c08d5b2a65131009fcf90f9051e6119936a --- /dev/null +++ b/wandb/run-20220811_082319-hrpkniwr/files/diff.patch @@ -0,0 +1,172 @@ +diff --git a/run.recover.sh b/run.recover.sh +index 77ad3fd..632a336 100755 +--- a/run.recover.sh ++++ b/run.recover.sh +@@ -11,9 +11,8 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c + --per_device_train_batch_size="2" \ + --per_device_eval_batch_size="2" \ + --gradient_accumulation_steps="1" \ +- --precision="full_mixed" \ ++ --precision="half_mixed" \ + --matmul_precision="bfloat16" \ +- --multisteps \ + --learning_rate="6.394633237505332e-05" \ + --skip_steps="275000" \ + --warmup_steps="2000" \ +diff --git a/run_flax_speech_recognition_ctc.py b/run_flax_speech_recognition_ctc.py +index a330879..688d006 100644 +--- a/run_flax_speech_recognition_ctc.py ++++ b/run_flax_speech_recognition_ctc.py +@@ -415,12 +415,12 @@ class MixedPrecisionTrainState(struct.PyTreeNode): + ) + + @classmethod +- def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs): ++ def create(cls, *, apply_fn, params, tx, to_dtype, step=0, **kwargs): + """Creates a new instance with `step=0` and initialized `opt_state`.""" + # downcast optimizer state to bf16 if mixed-precision training + opt_state = tx.init(to_dtype(params)) if tx is not None else None + return cls( +- step=0, ++ step=step, + apply_fn=apply_fn, + params=params, + tx=tx, +@@ -1339,6 +1339,7 @@ def main(): + + # Setup train state + state = MixedPrecisionTrainState.create( ++ step=data_args.skip_steps, + apply_fn=model.__call__, + get_attention_mask_fn=model._get_feature_vector_attention_mask, + params=model.params, +@@ -1441,10 +1442,10 @@ def main(): + + # Create parallel version of the train and eval step + if training_args.do_train: +- p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) ++ p_train_step = jax.jit(jax.pmap(train_step, "batch", donate_argnums=(0,))) + + if training_args.do_eval: +- p_eval_step = jax.pmap(eval_step, "batch") ++ p_eval_step = jax.jit(jax.pmap(eval_step, "batch")) + + def run_evaluation(step): + if training_args.do_eval: +@@ -1520,11 +1521,10 @@ def main(): + + if epoch < skip_epochs: + 
logger.info(f"Skipping epoch {epoch + 1}") ++ # Create sampling rng ++ rng, input_rng = jax.random.split(rng) + continue + +- # Create sampling rng +- rng, input_rng = jax.random.split(rng) +- + # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length + train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng) + train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update) +@@ -1565,6 +1565,8 @@ def main(): + + if training_args.eval_steps and cur_step % training_args.eval_steps == 0: + run_evaluation(cur_step) ++ p_train_step.clear_cache() ++ p_eval_step.clear_cache() + + if cur_step % training_args.save_steps == 0: + save_checkpoint(cur_step) +diff --git a/special_tokens_map.json b/special_tokens_map.json +index 218961f..96287dc 100644 +--- a/special_tokens_map.json ++++ b/special_tokens_map.json +@@ -399,6 +399,62 @@ + "rstrip": false, + "single_word": false + }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, + { + "content": "", + "lstrip": false, +diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log +index 23926ef..ce349ec 120000 +--- a/wandb/debug-internal.log ++++ b/wandb/debug-internal.log +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4/logs/debug-internal.log +\ No newline at end of file ++run-20220811_082319-hrpkniwr/logs/debug-internal.log +\ No newline at end of file +diff --git a/wandb/debug.log b/wandb/debug.log +index 279853d..b60911c 120000 +--- a/wandb/debug.log ++++ b/wandb/debug.log +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4/logs/debug.log +\ No newline at end of file ++run-20220811_082319-hrpkniwr/logs/debug.log +\ No newline at end of file +diff --git a/wandb/latest-run b/wandb/latest-run +index f069a7a..78672c4 120000 +--- a/wandb/latest-run ++++ b/wandb/latest-run +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4 +\ No newline at end of file ++run-20220811_082319-hrpkniwr +\ No newline at end of file diff --git a/wandb/run-20220811_082319-hrpkniwr/files/output.log b/wandb/run-20220811_082319-hrpkniwr/files/output.log new file mode 100644 index 0000000000000000000000000000000000000000..53f217c0821c4ac53395c26b65a6d35217a004b5 --- /dev/null +++ b/wandb/run-20220811_082319-hrpkniwr/files/output.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d68285d3ad7a01eab94bf633e13fe2dba88257cd58a47c1977fbe8ff5ec38fc +size 213305 diff --git a/wandb/run-20220811_082319-hrpkniwr/files/requirements.txt b/wandb/run-20220811_082319-hrpkniwr/files/requirements.txt new file mode 100644 index 
0000000000000000000000000000000000000000..e0273eb6554b8538eecc3cb9f4a47c988bd3d0dd --- /dev/null +++ b/wandb/run-20220811_082319-hrpkniwr/files/requirements.txt @@ -0,0 +1,158 @@ +absl-py==1.0.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +appdirs==1.4.4 +astunparse==1.6.3 +async-timeout==4.0.2 +attrs==21.4.0 +audioread==2.1.9 +backcall==0.2.0 +cachetools==4.2.4 +certifi==2021.10.8 +cffi==1.15.1 +charset-normalizer==2.0.10 +chex==0.1.3 +click==8.0.3 +cloud-tpu-client==0.10 +cloud-tpu-profiler==2.4.0 +clu==0.0.6 +colorama==0.4.5 +commonmark==0.9.1 +configparser==5.2.0 +contextlib2==21.6.0 +cycler==0.11.0 +datasets==2.4.0 +decorator==5.1.0 +dill==0.3.4 +dm-tree==0.1.6 +docker-pycreds==0.4.0 +etils==0.6.0 +exceptiongroup==1.0.0rc8 +filelock==3.4.2 +flatbuffers==2.0 +flax==0.5.3 +fonttools==4.28.5 +frozenlist==1.2.0 +fsspec==2021.11.1 +future==0.18.2 +gast==0.4.0 +gitdb==4.0.9 +gitpython==3.1.26 +google-api-core==1.31.5 +google-api-python-client==1.8.0 +google-auth-httplib2==0.1.0 +google-auth-oauthlib==0.4.6 +google-auth==2.3.3 +google-pasta==0.2.0 +googleapis-common-protos==1.54.0 +grpcio==1.43.0 +h5py==3.6.0 +httplib2==0.20.2 +huggingface-hub==0.2.1 +hypothesis==6.53.0 +idna==3.3 +importlib-metadata==4.10.0 +importlib-resources==5.4.0 +ipython==7.31.0 +jax==0.3.15 +jaxlib==0.3.15 +jedi==0.18.1 +jiwer==2.3.0 +joblib==1.1.0 +keras-preprocessing==1.1.2 +keras==2.7.0 +kiwisolver==1.3.2 +libclang==12.0.0 +librosa==0.9.2 +libtpu-nightly==0.1.dev20220722 +llvmlite==0.39.0 +markdown==3.3.6 +matplotlib-inline==0.1.3 +matplotlib==3.5.1 +ml-collections==0.1.0 +msgpack==1.0.3 +multidict==5.2.0 +multiprocess==0.70.12.2 +numba==0.56.0 +numpy==1.22.0 +oauth2client==4.1.3 +oauthlib==3.1.1 +opt-einsum==3.3.0 +optax==0.1.3 +packaging==21.3 +pandas==1.3.5 +parso==0.8.3 +pathtools==0.1.2 +pexpect==4.8.0 +pickleshare==0.7.5 +pillow==9.0.0 +pip==22.2.2 +pkg-resources==0.0.0 +pooch==1.6.0 +promise==2.3 +prompt-toolkit==3.0.24 +protobuf==3.19.1 +psutil==5.9.0 +ptyprocess==0.7.0 +pyarrow==6.0.1 +pyasn1-modules==0.2.8 +pyasn1==0.4.8 +pycparser==2.21 +pyctcdecode==0.4.0 +pygments==2.11.1 +pygtrie==2.5.0 +pyparsing==3.0.6 +python-dateutil==2.8.2 +python-levenshtein==0.12.2 +pytz==2021.3 +pyyaml==6.0 +regex==2021.11.10 +requests-oauthlib==1.3.0 +requests==2.27.0 +resampy==0.3.1 +responses==0.18.0 +rich==11.2.0 +rsa==4.8 +sacremoses==0.0.46 +scikit-learn==1.1.1 +scipy==1.7.3 +sentry-sdk==1.5.2 +setuptools==44.0.0 +shortuuid==1.0.8 +six==1.16.0 +smmap==5.0.0 +sortedcontainers==2.4.0 +soundfile==0.10.3.post1 +sox==1.4.1 +subprocess32==3.5.4 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.0 +tensorboard==2.7.0 +tensorflow-cpu==2.7.0 +tensorflow-datasets==4.4.0 +tensorflow-estimator==2.7.0 +tensorflow-io-gcs-filesystem==0.23.1 +tensorflow-metadata==1.5.0 +tensorflow==2.7.0 +tensorstore==0.1.21 +termcolor==1.1.0 +threadpoolctl==3.1.0 +tokenizers==0.11.2 +toolz==0.11.2 +torch==1.12.0 +torchaudio==0.12.0+cpu +tqdm==4.62.3 +traitlets==5.1.1 +transformers==4.21.0 +typing-extensions==4.3.0 +uritemplate==3.0.1 +urllib3==1.26.7 +wandb==0.12.9 +wcwidth==0.2.5 +werkzeug==2.0.2 +wheel==0.37.1 +wrapt==1.13.3 +xxhash==2.0.2 +yarl==1.7.2 +yaspin==2.1.0 +zipp==3.7.0 \ No newline at end of file diff --git a/wandb/run-20220811_082319-hrpkniwr/files/wandb-metadata.json b/wandb/run-20220811_082319-hrpkniwr/files/wandb-metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..6b30f794ede2c972624cf908be791fe4784d4d32 --- /dev/null +++ b/wandb/run-20220811_082319-hrpkniwr/files/wandb-metadata.json @@ -0,0 +1,69 @@ +{ 
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29", + "python": "3.8.10", + "heartbeatAt": "2022-08-11T08:23:22.887427", + "startedAt": "2022-08-11T08:23:19.499809", + "docker": null, + "cpu_count": 96, + "cuda": null, + "args": [ + "--model_name_or_path=./", + "--hub_model_id=NbAiLab/wav2vec2-1b-npsc-nst-tpu", + "--tokenizer_name=./", + "--output_dir=./", + "--overwrite_output_dir", + "--num_train_epochs=40", + "--per_device_train_batch_size=2", + "--per_device_eval_batch_size=2", + "--gradient_accumulation_steps=1", + "--precision=half_mixed", + "--matmul_precision=bfloat16", + "--learning_rate=6.394633237505332e-05", + "--skip_steps=275000", + "--warmup_steps=2000", + "--length_column_name=input_length", + "--evaluation_strategy=steps", + "--text_column_name=text", + "--save_steps=5000", + "--eval_steps=5000", + "--logging_steps=100", + "--layerdrop=0.041", + "--attention_dropout=0.094", + "--activation_dropout=0.055", + "--hidden_dropout=0.047", + "--save_total_limit=5", + "--freeze_feature_encoder", + "--feat_proj_dropout=0.04", + "--mask_time_prob=0.082", + "--mask_time_length=10", + "--mask_feature_prob=0.25", + "--mask_feature_length=64", + "--gradient_checkpointing", + "--min_duration_in_seconds=0.5", + "--max_duration_in_seconds=30.0", + "--use_auth_token", + "--seed=42", + "--group_by_length", + "--do_train", + "--do_eval", + "--push_to_hub", + "--preprocessing_num_workers=32", + "--ctc_zero_infinity", + "--do_lower_case", + "--wandb_project=wav2vec2", + "--wandb_name=wav2vec2-1b-npsc-nst-tpu (cont.)", + "--remove_punctuation" + ], + "state": "running", + "program": "run_flax_speech_recognition_ctc.py", + "codePath": "run_flax_speech_recognition_ctc.py", + "git": { + "remote": "https://huggingface.co/NbAiLab/wav2vec2-1b-npsc-nst-tpu", + "commit": "f624ac4bfedfbf56891676d7c5f2e37b4c8e0745" + }, + "email": "versae@gmail.com", + "root": "/data/wav2vec2-1b-npsc-nst-tpu", + "host": "t1v-n-eedfb410-w-0", + "username": "javierr", + "executable": "/data/flax/bin/python" +} diff --git a/wandb/run-20220811_082319-hrpkniwr/files/wandb-summary.json b/wandb/run-20220811_082319-hrpkniwr/files/wandb-summary.json new file mode 100644 index 0000000000000000000000000000000000000000..7901fb56deabb8234b6a43cc6e18523e5c53be48 --- /dev/null +++ b/wandb/run-20220811_082319-hrpkniwr/files/wandb-summary.json @@ -0,0 +1 @@ +{"_wandb": {"runtime": 1031}} \ No newline at end of file diff --git a/wandb/run-20220811_082319-hrpkniwr/logs/debug-internal.log b/wandb/run-20220811_082319-hrpkniwr/logs/debug-internal.log new file mode 100644 index 0000000000000000000000000000000000000000..d580475b7c9de038c8c6dfc41fc02b66563e4dce --- /dev/null +++ b/wandb/run-20220811_082319-hrpkniwr/logs/debug-internal.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef1bf6cae1121622fefe50860842e78907f5616f8501c14952d31a1146fb1ac5 +size 49317 diff --git a/wandb/run-20220811_082319-hrpkniwr/logs/debug.log b/wandb/run-20220811_082319-hrpkniwr/logs/debug.log new file mode 100644 index 0000000000000000000000000000000000000000..f64599fcda455b394fe39a9306ee0acbb9162ef1 --- /dev/null +++ b/wandb/run-20220811_082319-hrpkniwr/logs/debug.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0fd14bc9f85133b808cef62e3a6382de461ebc53a51bea6d20e9a1fea8a355da +size 5853 diff --git a/wandb/run-20220811_082319-hrpkniwr/run-hrpkniwr.wandb b/wandb/run-20220811_082319-hrpkniwr/run-hrpkniwr.wandb new file mode 100644 index 
0000000000000000000000000000000000000000..b6d5bf416d90463606821ee74698bb7db307fd5d --- /dev/null +++ b/wandb/run-20220811_082319-hrpkniwr/run-hrpkniwr.wandb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c29443be2ca02227f8767117dc308d208949e0eac7487040d8d178a935d5aadb +size 232091 diff --git a/wandb/run-20220811_085413-2dwhhb1y/files/code/run_flax_speech_recognition_ctc.py b/wandb/run-20220811_085413-2dwhhb1y/files/code/run_flax_speech_recognition_ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..688d0067bf5a29e7f75a50534392557c8a80a709 --- /dev/null +++ b/wandb/run-20220811_085413-2dwhhb1y/files/code/run_flax_speech_recognition_ctc.py @@ -0,0 +1,1633 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2022 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition. +""" +# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. + +import logging +import math +import os +import re +import sys +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + +import datasets +import numpy as np +from datasets import DatasetDict, load_dataset, load_metric +from tqdm import tqdm + +import flax +import jax +import jax.numpy as jnp +import optax +import transformers +import wandb as wandb +from flax import core, jax_utils, struct, traverse_util +from flax.jax_utils import unreplicate, pad_shard_unpad +from flax.training.common_utils import get_metrics, shard, shard_prng_key +from huggingface_hub import Repository +from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC +from optax._src import linear_algebra +from transformers import ( + AutoFeatureExtractor, + AutoProcessor, + AutoTokenizer, + HfArgumentParser, + TrainingArguments, + is_tensorboard_available, + set_seed, +) +from transformers.file_utils import get_full_repo_name +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.17.0.dev0") + +require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") + +logger = logging.getLogger(__name__) + + +@flax.struct.dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
+ """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + feature_extractor_name: Optional[str] = field( + default=None, metadata={"help": "feature extractor name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + freeze_feature_encoder: bool = field( + default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."} + ) + attention_dropout: float = field( + default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."} + ) + activation_dropout: float = field( + default=0.1, + metadata={ + "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler." + }, + ) + hidden_dropout: float = field( + default=0.1, + metadata={ + "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler." + }, + ) + feat_proj_dropout: float = field( + default=0.0, + metadata={ + "help": "The feat proj dropout probability for feature encoder representations." + }, + ) + final_dropout: float = field( + default=0.0, + metadata={"help": "The dropout probability for the final projection layer."}, + ) + mask_time_prob: float = field( + default=0.1, + metadata={ + "help": "The spec aug dropout probability for feature encoder representations." + }, + ) + mask_time_length: int = field( + default=10, + metadata={"help": "Length of vector span to mask along the time axis."}, + ) + mask_feature_prob: float = field( + default=0.0, + metadata={ + "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector" + "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis." + }, + ) + mask_feature_length: int = field( + default=10, + metadata={"help": "Length of vector span to mask along the feature axis."}, + ) + layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."}) + ctc_loss_reduction: Optional[str] = field( + default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."} + ) + ctc_zero_infinity: Optional[bool] = field( + default=False, metadata={"help": "If True, will try yo aboud the CTC loss goinf to infinity."} + ) + + +@flax.struct.dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. 
+ """ + + dataset_name: str = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + text_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, + ) + dataset_cache_dir: Optional[str] = field( + default=None, metadata={"help": "Path to cache directory for saving and loading datasets"} + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_test_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of test examples to this " + "value if set." + }, + ) + audio_column_name: str = field( + default="audio", + metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"}, + ) + text_column_name: str = field( + default="text", + metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"}, + ) + max_duration_in_seconds: float = field( + default=20.0, + metadata={ + "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`" + }, + ) + min_duration_in_seconds: float = field( + default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"} + ) + max_label_length: Optional[int] = field( + default=512, + metadata={ + "help": "The minimum total sequence length for target text after tokenization. Sequences shorter " + "than this will be filtered." + }, + ) + min_label_length: Optional[int] = field( + default=2, + metadata={ + "help": "The minimum total sequence length for target text after tokenization. Sequences shorter " + "than this will be filtered." + }, + ) + pad_input_to_multiple_of: Optional[int] = field( + default=32000, + metadata={ + "help": "If set will pad the input sequence to a multiple of the provided value. " + "This is important to avoid triggering recompilations on TPU." + }, + ) + pad_target_to_multiple_of: Optional[int] = field( + default=None, + metadata={ + "help": "If set will pad the target sequence to a multiple of the provided value. " + "This is important to avoid triggering recompilations on TPU." + }, + ) + preprocessing_only: bool = field( + default=False, + metadata={ + "help": "Whether to only do data preprocessing and skip training. " + "This is especially useful when data preprocessing errors out in distributed training due to timeout. 
" + "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` " + "so that the cached datasets can consequently be loaded in distributed training" + }, + ) + train_split_name: str = field( + default="train", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + eval_split_name: str = field( + default="validation", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + do_lower_case: bool = field( + default=True, + metadata={"help": "Whether the target text should be lower cased."}, + ) + wandb_project: str = field( + default="flax-speech-recognition-ctc", + metadata={"help": "The name of the wandb project."}, + ) + wandb_name: str = field( + default=None, + metadata={"help": "The name of the wandb run."}, + ) + wandb_job_type: str = field( + default="CTC", + metadata={"help": "The name of the wandb job type."}, + ) + test_split_name: str = field( + default="test", + metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"}, + ) + remove_punctuation: bool = field( + default=False, metadata={"help": "Whether or not to remove punctuation during training."} + ) + skip_steps: Optional[int] = field( + default=0, + metadata={ + "help": "Skip this number of steps. Useful to continue training" + }, + ) + + +# @flax.struct.dataclass +@dataclass +class FlaxTrainingArguments(TrainingArguments): + precision: str = field( + default="full", + metadata={ + "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision" + "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**" + }, + ) + matmul_precision: str = field( + default="default", + metadata={ + "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. " + "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). " + "This configuration option does not change the behaviours of such calls with explicit precision arguments; " + "it only changes the behaviors of calls with no such argument provided. " + "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`." + }, + ) + multisteps: bool = field( + default=False, + metadata={ + "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, " + "a custom gradient accumulation implementation will be employed." + }, + ) + + +def to_fp32(t): + return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t) + + +def to_bf16(t): + return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t) + + +class MixedPrecisionTrainState(struct.PyTreeNode): + """Train state for use with a single Optax optimizer. 
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py + + Synopsis:: + + state = TrainState.create( + apply_fn=model.apply, + params=variables['params'], + tx=tx) + grad_fn = jax.grad(make_loss_fn(state.apply_fn)) + for batch in data: + grads = grad_fn(state.params, batch) + state = state.apply_gradients(grads=grads) + + Args: + step: Counter starts at 0 and is incremented by every call to + `.apply_gradients()`. + apply_fn: Usually set to `model.apply()`. Kept in this dataclass for + convenience to have a shorter params list for the `train_step()` function + in your training loop. + params: The parameters to be updated by `tx` and used by `apply_fn`. + tx: An Optax gradient transformation. + opt_state: The state for `tx`. + dropout_rng: PRNG key for stochastic operations. + bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. + """ + + step: int + apply_fn: Callable = struct.field(pytree_node=False) + get_attention_mask_fn: Callable = struct.field(pytree_node=False) + params: core.FrozenDict[str, Any] + tx: optax.GradientTransformation = struct.field(pytree_node=False) + opt_state: optax.OptState + dropout_rng: jnp.ndarray + max_grad_norm: Optional[float] = 1.0 + + def apply_gradients(self, *, grads, to_dtype, **kwargs): + """Updates `step`, `params`, `opt_state` and `**kwargs` in return value. + + Note that internally this function calls `.tx.update()` followed by a call + to `optax.apply_updates()` to update `params` and `opt_state`. + + Args: + grads: Gradients that have the same pytree structure as `.params`. + **kwargs: Additional dataclass attributes that should be `.replace()`-ed. + + Returns: + An updated instance of `self` with `step` incremented by one, `params` + and `opt_state` updated by applying `grads`, and additional attributes + replaced as specified by `kwargs`. + """ + + # clip gradients by global l2 norm + casted_max_grad_norm = to_dtype(self.max_grad_norm) + g_norm = linear_algebra.global_norm(grads) + g_norm = jnp.maximum(casted_max_grad_norm, g_norm) + grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads) + + # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training + # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is) + updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params) + + new_params = optax.apply_updates(self.params, updates) + return self.replace( + step=self.step + 1, + params=new_params, + opt_state=to_dtype(new_opt_state), + **kwargs, + ) + + @classmethod + def create(cls, *, apply_fn, params, tx, to_dtype, step=0, **kwargs): + """Creates a new instance with `step=0` and initialized `opt_state`.""" + # downcast optimizer state to bf16 if mixed-precision training + opt_state = tx.init(to_dtype(params)) if tx is not None else None + return cls( + step=step, + apply_fn=apply_fn, + params=params, + tx=tx, + opt_state=opt_state, + **kwargs, + ) + + def replicate(self): + return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) + + +@flax.struct.dataclass +class FlaxDataCollatorSpeechSeq2SeqWithPadding: + """ + Data collator that will dynamically pad the inputs received. + Args: + processor ([`Wav2Vec2Processor`]) + The processor used for proccessing the data. + decoder_start_token_id (:obj: `int`) + The begin-of-sentence of the decoder. 
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned input sequences (according to the model's padding side and padding index) + among: + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned target sequences (according to the model's padding side and padding index). + See above for details. + max_input_length (:obj:`float`, `optional`): + Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). + pad_input_to_multiple_of (:obj:`int`, `optional`): + If set will pad the input sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + pad_target_to_multiple_of (:obj:`int`, `optional`): + If set will pad the target sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + """ + + processor: Any + input_padding: Union[bool, str] = "longest" + label_padding: Union[bool, str] = "max_length" + pad_input_to_multiple_of: Optional[int] = None + pad_to_multiple_of_label: Optional[int] = None + max_input_length: Optional[float] = None + max_label_length: Optional[float] = None + + def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]: + # split inputs and labels since they have to be of different lengths and need + # different padding methods + input_features = [{"input_values": feature["input_values"]} for feature in features] + label_features = [{"input_ids": feature["labels"]} for feature in features] + + # reformat list to dict and set to pytorch format + batch = self.processor.feature_extractor.pad( + input_features, + max_length=self.max_input_length, + padding=self.input_padding, + pad_to_multiple_of=self.pad_input_to_multiple_of, + return_tensors="np", + ) + + labels_batch = self.processor.tokenizer.pad( + label_features, + max_length=self.max_label_length, + padding=self.label_padding, + pad_to_multiple_of=self.pad_to_multiple_of_label, + return_tensors="np", + ) + + labels = labels_batch["input_ids"] + labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1)) + labels = labels.filled(fill_value=-100) + + batch["labels"] = labels + + return batch + + +def get_grouped_indices( + dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None +) -> np.array: + """ + Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486) + Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar + lengths. 
To do this, the indices are: + + - randomly permuted (if a JAX rng is specified) + - grouped in mega-batches of size `mega_batch_mult * batch_size` + - sorted by length in each mega-batch + + The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of + maximum length placed first, so that an OOM happens sooner rather than later. + """ + lengths = dataset["input_length"] + + # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller. + if mega_batch_mult is None: + mega_batch_mult = min(len(lengths) // (batch_size * 4), 50) + # Just in case, for tiny datasets + if mega_batch_mult == 0: + mega_batch_mult = 1 + + # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler. + num_samples = len(lengths) + indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples) + + megabatch_size = mega_batch_mult * batch_size + megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] + megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches] + + # The rest is to get the biggest batch first. + # Since each megabatch is sorted by descending length, the longest element is the first + megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches] + max_idx = np.argmax(megabatch_maximums).item() + # Switch to put the longest batch in first position + # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch) + megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0] + + megabatches = np.array([i for megabatch in megabatches for i in megabatch]) + + return megabatches + + +def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray: + """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by + the batch size and `drop_last` is `True`, the last incomplete batch is dropped. 
Else, it is returned.""" + num_samples = len(samples_idx) + if drop_last: + samples_to_remove = num_samples % batch_size + if samples_to_remove != 0: + samples_idx = samples_idx[:-samples_to_remove] + sections_split = num_samples // batch_size + samples_idx = samples_idx.reshape((sections_split, batch_size)) + else: + sections_split = math.ceil(num_samples / batch_size) + samples_idx = np.array_split(samples_idx, sections_split) + return samples_idx + + +def write_train_metric(summary_writer, train_metrics, train_time, step): + summary_writer.scalar("train_time", train_time, step) + + train_metrics = get_metrics(train_metrics) + for key, vals in train_metrics.items(): + tag = f"train_{key}" + for i, val in enumerate(vals): + summary_writer.scalar(tag, val, step - len(vals) + i + 1) + + +def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None): + for metric_name, value in eval_metrics.items(): + summary_writer.scalar(f"eval_{metric_name}", value, step) + + if pred_str is not None: + # write output actual predictions for debugging + summary_writer.text("eval_predictions", "\n".join(pred_str), step) + + +def write_wandb_log(metrics, step, prefix=None): + if jax.process_index() == 0: + log_metrics = {} + for k, v in metrics.items(): + if "layer" in k: + log_metrics[f"{k}/"] = v + elif prefix is not None: + log_metrics[f"{prefix}/{k}"] = v + else: + log_metrics[k] = v + wandb.log(log_metrics, step) + + +def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"): + if jax.process_index() == 0: + # convert str data to a wandb compatible format + str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))] + # we'll log the first 50 predictions for each epoch + wandb.log( + { + f"{prefix}/step_{int(step / 1000)}k": wandb.Table( + columns=["label_str", "pred_str"], data=str_data[:num_log] + ) + }, + step, + ) + + +def create_learning_rate_fn( + num_train_steps: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.array]: + """Returns a linear warmup, linear_decay learning rate function.""" + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +def ctc_loss( + logits, + logits_attention_mask, + labels, + blank_id, + loss_reduction="mean", + output_emission_dict=False, + log_epsilon=-100000.0, +): + """Computes CTC loss. + This function performs forward computation over an FSA with `N * 2` states + where `N` is the max number of labels. The states are split into two groups: + Phi states and emission states. a phi-state accepts repetition of + phi (blank)-symbols and transits to emission state when the correct label is + observed. An emission state accepts repetition of the label and transits to + the next phi states at any time (so called epsilon-transition). + Below, `B` denotes the batch size, `T` denotes the time steps in `logits`, + and `N` denotes the time steps in `labels`. + Args: + logits: (B, T, K)-array containing log-probabilities of each class. + logitpaddings: (B, T)-array. Padding indicators for `logits`. + labels: (B, N)-array containing reference integer labels. + labelpaddings: (B, N)-array. Padding indicators for `labels`. Currently, + `labels` must be right-padded, i.e. 
each row of `labelpaddings` must be + repetition of zeroes, followed by repetition of ones. + blank_id: Id for blank token. + loss_reduction: one of "mean", "sum", "default" + - "none": no reduction is applied. + - "mean": output loss will be divided by target lengths and then the + mean over the batch is taken. + - "sum": output loss are summed over batch + output_emission_dict: whether to output additional information about the emission probs + Returns: + A pair of `(per_seq_loss, aux)`. + per_seq_loss: + (B,)-array containing loss values for each sequence in the batch. + aux: Dictionary containing interim variables used for computing losses. + aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each + phi-state corresponding to the n-th label. + aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each + emission-state corresponding to the n-th label. + aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol + corresponding to each time frame. + aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label + corresponding to each time frame. + """ + # label paddings are indicated by -100 + labelpaddings = labels < 0 + # logit paddings are the inverse of attention_mask + logitpaddings = ~logits_attention_mask + + # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py + batchsize, unused_maxinputlen, num_classes = logits.shape + batchsize_, maxlabellen = labels.shape + + logprobs = jax.nn.log_softmax(logits) + labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32) + + # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1]. + repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32) + repeat = jnp.pad(repeat, ((0, 0), (0, 1))) + + logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1] + logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] + + one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K] + logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot) + logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] + + logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N] + logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0) + logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N] + + def loop_body(prev, x): + prev_phi, prev_emit = prev + # emit-to-phi epsilon transition, except if the next label is repetition + prev_phi_orig = prev_phi + prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat)) + + logprob_emit, logprob_phi, pad = x + + # phi-to-emit transition + next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit) + # self-loop transition + next_phi = prev_phi + logprob_phi + # emit-to-phi blank transition only when the next label is repetition + next_phi = next_phi.at[:, 1:].set( + jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)) + ) + + pad = pad.reshape((batchsize, 1)) + next_emit = pad * prev_emit + (1.0 - pad) * next_emit + next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi + + return (next_phi, next_emit), (next_phi, next_emit) + + xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0))) + _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs) + + # last row needs to be updated with the last epsilon transition + logalpha_phi_last = logalpha_phi[-1].at[:, 
1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1])) + logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last) + + # extract per_seq_loss + one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1] + per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot) + + if loss_reduction == "mean": + target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1) + loss = (per_seq_loss / target_lengths).mean() + elif loss_reduction == "sum": + loss = per_seq_loss.sum() + else: + loss = per_seq_loss + + if not output_emission_dict: + return loss + + return loss, { + "logalpha_phi": logalpha_phi, + "logalpha_emit": logalpha_emit, + "logprobs_phi": logprobs_phi, + "logprobs_emit": logprobs_emit, + } + + +def make_dataset(data_args, seed=42): + # Pre-processing dataset + import re + + def map_nst(entry): + text = entry["text"].lower() + text = text.replace("(...vær stille under dette opptaket...)", "") + text = re.sub('[áàâ]', 'a', text) + text = re.sub('[ä]', 'æ', text) + text = re.sub('[éèëê]', 'e', text) + text = re.sub('[íìïî]', 'i', text) + text = re.sub('[óòöô]', 'o', text) + text = re.sub('[ö]', 'ø', text) + text = re.sub('[ç]', 'c', text) + text = re.sub('[úùüû]', 'u', text) + # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text) + text = re.sub('\s+', ' ', text) + return {"text": text} + + def filter_nst(entry): + if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)): + return False # Too short + if re.match(entry["type"], "pIW|CA"): + return False # Spelling out words + return True + + def filter_npsc(entry): + # False if there are digits in the text + if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)): + return False # Too short + if re.search("\d", entry["text"]): + return False + return True + + def map_npsc(entry): + batch = {"text": entry["text"].lower()} + batch["text"] = re.sub('[áàâ]', 'a', batch["text"]) + batch["text"] = re.sub('[ä]', 'æ', batch["text"]) + batch["text"] = re.sub('[éèëê]', 'e', batch["text"]) + batch["text"] = re.sub('[íìïî]', 'i', batch["text"]) + batch["text"] = re.sub('[óòöô]', 'o', batch["text"]) + batch["text"] = re.sub('[ö]', 'ø', batch["text"]) + batch["text"] = re.sub('[ç]', 'c', batch["text"]) + batch["text"] = re.sub('[úùüû]', 'u', batch["text"]) + batch["text"] = re.sub('\s', ' ', batch["text"]) + batch["text"] = re.sub('', 'eee', batch["text"]) + batch["text"] = re.sub('', 'qqq', batch["text"]) + batch["text"] = re.sub('', 'mmm', batch["text"]) + batch["text"] = re.sub('', 'xxx', batch["text"]) + # batch["text"] = re.sub('', '?', batch["text"]) + if "<" in batch["text"]: + raise ValueError(batch["text"]) + return batch + + nst = datasets.load_dataset("NbAiLab/NST", "no-close") + npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3") + # TODO NST_hesitate + + split = len(npsc[data_args.train_split_name]) / (len(npsc[data_args.train_split_name]) + len(npsc[data_args.eval_split_name])) # Use same train/val ratio as NPSC + nst_train = nst[data_args.train_split_name].train_test_split(train_size=split, seed=seed) + nst[data_args.train_split_name] = nst_train["train"] + nst[data_args.eval_split_name] = nst_train["test"] + + nst = nst.filter(filter_nst).map( + map_nst, + num_proc=data_args.preprocessing_num_workers, + desc="filtering NST", + ).shuffle(seed=seed) + npsc = npsc.filter(filter_npsc).map( + map_npsc, + num_proc=data_args.preprocessing_num_workers, + desc="filtering NPSC", 
+ ).shuffle(seed=seed) + + npsc_base = npsc.remove_columns([col for col in npsc[data_args.train_split_name].column_names if col not in ["text", "audio"]]) + nst_base = nst.remove_columns([col for col in nst[data_args.train_split_name].column_names if col not in ["text", "audio"]]) + + combined = {} + for split in data_args.train_split_name, data_args.eval_split_name, data_args.test_split_name: + probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples + probs = (probs / probs.sum()).tolist() + comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed) + combined[split] = comb + + return datasets.DatasetDict(**combined) + +def main(): + # 1. Parse input arguments + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # 2. Setup logging + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + # Set the verbosity to info of the Transformers logger. + # We only want one process per machine to log things on the screen. + logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) + if jax.process_index() == 0: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # Set up wandb run + if jax.process_index() == 0: + wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type) + + logger.info("Training/evaluation parameters %s", training_args) + + # Set the default TPU matmul precision and display the number of devices + jax.config.update("jax_default_matmul_precision", training_args.matmul_precision) + logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}") + + # 4. 
Load dataset + + set_seed(training_args.seed) + raw_datasets = make_dataset(data_args, seed=training_args.seed) + + # raw_datasets = DatasetDict() + + # if training_args.do_train: + # raw_datasets[data_args.train_split_name] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=data_args.train_split_name, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + # if training_args.do_eval: + # raw_datasets[data_args.eval_split_name] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=data_args.eval_split_name, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + # if training_args.do_predict: + # test_split = data_args.test_split_name.split("+") + # for split in test_split: + # raw_datasets[split] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=split, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + if not training_args.do_train and not training_args.do_eval and not training_args.do_predict: + raise ValueError( + "Cannot not train, not do evaluation and not do prediction. At least one of " + "training, evaluation or prediction has to be done." + ) + + # if not training, there is no need to run multiple epochs + if not training_args.do_train: + training_args.num_train_epochs = 1 + + if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--audio_column_name` to the correct audio column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--text_column_name` to the correct text column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + # 5. 
Load pretrained model, tokenizer, and feature extractor + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + config = Wav2Vec2Config.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + # update config according to training args, model args, and tokenizer attributes + config.update( + { + "feat_proj_dropout": model_args.feat_proj_dropout, + "attention_dropout": model_args.attention_dropout, + "hidden_dropout": model_args.hidden_dropout, + "final_dropout": model_args.final_dropout, + "mask_time_prob": model_args.mask_time_prob, + "mask_time_length": model_args.mask_time_length, + "mask_feature_prob": model_args.mask_feature_prob, + "mask_feature_length": model_args.mask_feature_length, + "gradient_checkpointing": training_args.gradient_checkpointing, + "layerdrop": model_args.layerdrop, + "ctc_loss_reduction": model_args.ctc_loss_reduction, + "ctc_zero_infinity": model_args.ctc_zero_infinity, + "pad_token_id": tokenizer.pad_token_id, + "vocab_size": tokenizer.vocab_size, # len(tokenizer), + "activation_dropout": model_args.activation_dropout, + } + ) + + if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr": + raise ValueError( + "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to " + "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus," + "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is mostly likely " + "built on an lowercased corpus. In this case, set `tokenizer.do_lower_case` to ``False`." + ) + + if training_args.precision == "full_mixed": + dtype = jnp.bfloat16 + training_args.mixed_precision = True + elif training_args.precision == "half_mixed": + dtype = jnp.bfloat16 + training_args.mixed_precision = False + else: + dtype = jnp.float32 + training_args.mixed_precision = False + + try: + model = FlaxWav2Vec2ForCTC.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + except: + model = FlaxWav2Vec2ForCTC.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + from_pt=True, + ) + + # 6. Resample speech dataset ALWAYS + raw_datasets = raw_datasets.cast_column( + data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) + ) + + # 7. Preprocessing the datasets. 
+ # We need to read the audio files as arrays and tokenize the targets. + max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate) + min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate) + max_target_length = data_args.max_label_length + min_target_length = data_args.min_label_length + pad_input_to_multiple_of = data_args.pad_input_to_multiple_of + audio_column_name = data_args.audio_column_name + num_workers = data_args.preprocessing_num_workers + text_column_name = data_args.text_column_name + model_input_name = feature_extractor.model_input_names[0] + do_lower_case = data_args.do_lower_case + dataset_name = data_args.dataset_name + chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ") + chars_to_ignore_regex = f'[{"".join(chars_to_ignore)}]' + # gigaspeech_punctuation = {" ": ",", " ": ".", " ": "?", " ": "!"} + # gigaspeech_disfluencies = ["", ""] + # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "", "", "", "[laughter-", + # "[vocalized-noise]", "_1"] + # swb_punctuations = ["{", "}", "[", "]-", "]"] + # earnings_disfluencies = ["", "", "", "inaudible", "", ""] + ignore_segments = ["ignore_time_segment_in_scoring", "", "", "[noise]", "[laughter]", "[silence]", + "[vocalized-noise]", "", "", "", "", "", "", ""] + + if training_args.do_train and data_args.max_train_samples is not None: + raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].select(range(data_args.max_train_samples)) + + if training_args.do_eval and data_args.max_eval_samples is not None: + raw_datasets[data_args.eval_split_name] = raw_datasets[data_args.eval_split_name].select(range(data_args.max_eval_samples)) + + if training_args.do_predict and data_args.max_test_samples is not None: + raw_datasets[data_args.test_split_name] = raw_datasets[data_args.test_split_name].select(range(data_args.max_test_samples)) + + if training_args.do_train and data_args.remove_punctuation: + + def remove_punctuation(batch): + batch[text_column_name] = ( + re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "") + ) + return batch + + raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].map( + remove_punctuation, + num_proc=data_args.preprocessing_num_workers, + desc="removing punctuation from train split", + ) + + # filter data where the targets are ignored in scoring + def is_target_labels(input_str): + return input_str.lower() not in ignore_segments + + raw_datasets = raw_datasets.filter( + is_target_labels, + num_proc=num_workers, + input_columns=[text_column_name], + desc="filtering data where the targets are ignored in scoring", + ) + + def prepare_dataset(batch): + # process audio + try: + sample = batch[audio_column_name] + except ValueError: + sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate} + inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"]) + # process audio length + batch[model_input_name] = inputs.input_values[0] + batch["input_length"] = len(batch["input_values"]) + + # process targets + input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name] + + # if dataset_name == "google/xtreme_s": + # # Finally, we tokenize the processed text + # batch["labels"] = tokenizer(input_str).input_ids + # batch["labels_length"] = len(batch["labels"]) + # return batch + + # # Common Voice 9 + # if input_str.startswith('"') and input_str.endswith('"'): + # # we can remove
trailing quotation marks as they do not affect the transcription + # input_str = input_str[1:-1] + # # normalize quotation marks + # input_str = re.sub(r'["“”]', '"', input_str) + # # normalize apostrophes + # input_str = re.sub(r"[’']", "'", input_str) + # # normalize hyphens + # input_str = re.sub(r"[—–]", "-", input_str) + # # replace double quotation marks with single + # input_str = input_str.replace('""', '"') + # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str): + # # for CV9, we'll normalize the text to always finish with punctuation + # if input_str[-1] not in [".", "?", "!"]: + # input_str = input_str + "." + + # # TEDLIUM-3 + # # delete the token from the text and replace spaced apostrophes with un-spaced + # input_str = input_str.replace("", "").replace(" '", "'") + + # # GigaSpeech + # for disfluency in gigaspeech_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # convert spelled out punctuation to symbolic form + # for punctuation, replacement in gigaspeech_punctuation.items(): + # input_str = input_str.replace(punctuation, replacement) + # if dataset_name == "speechcolab/gigaspeech" and len(input_str): + # # for GS, we'll normalize the text to always finish with punctuation + # if input_str[-1] not in [".", "?", "!"]: + # input_str = input_str + "." + + # # SWB + # for disfluency in swb_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # remove parenthesised text (test data only) + # input_str = re.sub("[\(].*?[\)]", "", input_str) + # for punctuation in swb_punctuations: + # input_str = input_str.replace(punctuation, "") + # # replace anomalous words with their correct transcriptions + # split_str = input_str.split("/") + # if len(split_str) > 1: + # input_str = " ".join( + # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]]) + + # # Earnings 22 + # for disfluency in earnings_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # replace mal-formatted ellipsis + # input_str = input_str.replace("…", ".") + + # JIWER compliance + # remove multiple spaces + input_str = re.sub(r"\s\s+", " ", input_str) + # strip trailing spaces + input_str = input_str.strip() + + # Finally, we tokenize the processed text + batch["labels"] = tokenizer(input_str).input_ids + batch["labels_length"] = len(batch["labels"]) + return batch + + vectorized_datasets = raw_datasets.map( + prepare_dataset, + remove_columns=next(iter(raw_datasets.values())).column_names, + num_proc=num_workers, + desc="preprocess dataset", + ) + + # filter data with inputs shorter than min_input_length or longer than max_input_length + def is_audio_in_length_range(length): + return length > min_input_length and length < max_input_length + + vectorized_datasets = vectorized_datasets.filter( + is_audio_in_length_range, + num_proc=num_workers, + input_columns=["input_length"], + ) + + # filter data with targets shorter than min_target_length or longer than max_target_length + def is_labels_in_length_range(length): + return length > min_target_length # and length < max_target_length + + vectorized_datasets = vectorized_datasets.filter( + is_labels_in_length_range, + num_proc=num_workers, + input_columns=["labels_length"], + ) + + # for large datasets it is advised to run the preprocessing on a + # single machine first with `args.preprocessing_only` since there will mostly likely + # be a timeout when running the script in distributed mode. 
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the + # cached dataset + if data_args.preprocessing_only: + cache = {k: v.cache_files for k, v in vectorized_datasets.items()} + logger.info(f"Data preprocessing finished. Files cached at {cache}.") + return + + # 8. Load Metrics + wer_metric = load_metric("wer") + cer_metric = load_metric("cer") + + def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]): + padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids)) + + pred_str = tokenizer.batch_decode(pred_ids) + # we do not want to group tokens when computing the metrics + label_str = tokenizer.batch_decode(padded_ids, group_tokens=False) + + wer = wer_metric.compute(predictions=pred_str, references=label_str) + cer = cer_metric.compute(predictions=pred_str, references=label_str) + + return {"wer": wer, "cer": cer}, pred_str, label_str + + # 9. save feature extractor, tokenizer and config + feature_extractor.save_pretrained(training_args.output_dir) + tokenizer.save_pretrained(training_args.output_dir) + config.save_pretrained(training_args.output_dir) + + processor = AutoProcessor.from_pretrained(training_args.output_dir) + + data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding( + processor=processor, + input_padding="longest", + pad_input_to_multiple_of=pad_input_to_multiple_of, + max_label_length=data_args.max_label_length, + ) + + # Enable tensorboard only on the master node + has_tensorboard = is_tensorboard_available() + if has_tensorboard and jax.process_index() == 0: + try: + from flax.metrics.tensorboard import SummaryWriter + + summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) + except ImportError as ie: + has_tensorboard = False + logger.warning( + f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" + ) + else: + logger.warning( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run `pip install tensorboard` to enable." + ) + + # 10. Handle the repository creation + if training_args.push_to_hub: + with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f: + git_lfs_extensions = f.read() + if "*.wandb" not in git_lfs_extensions: + f.write("*.wandb filter=lfs diff=lfs merge=lfs -text") + if training_args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) + else: + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) + + # 11. 
Initialize our training + rng = jax.random.PRNGKey(training_args.seed) + rng, dropout_rng = jax.random.split(rng) + + # Store some constants + max_steps = int(training_args.max_steps) + gradient_accumulation_steps = int(training_args.gradient_accumulation_steps) + train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() + batch_size_per_update = train_batch_size * gradient_accumulation_steps + per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) + eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() + to_dtype = to_bf16 if training_args.mixed_precision else to_fp32 + + if training_args.do_train: + num_train_samples = len(vectorized_datasets[data_args.train_split_name]) + steps_per_epoch = num_train_samples // batch_size_per_update + if max_steps > 0: + num_epochs = -(training_args.max_steps // -steps_per_epoch) + total_train_steps = max_steps + else: + num_epochs = int(training_args.num_train_epochs) + total_train_steps = steps_per_epoch * num_epochs + + # Create learning rate schedule + # Create learning rate schedule + linear_decay_lr_schedule_fn = create_learning_rate_fn( + total_train_steps, + training_args.warmup_steps, + training_args.learning_rate, + ) + + # We use Optax's "masking" functionality to not apply weight decay + # to bias and LayerNorm scale parameters. decay_mask_fn returns a + # mask boolean with the same structure as the parameters. + # The mask is True for parameters that should be decayed. + # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart. + # For FlaxT5, one should correct the layer norm parameter naming + # accordingly - see `run_t5_mlm_flax.py` e.g. + def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + layer_norm_params = [ + (name, "scale") + for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"] + ] + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + + if training_args.adafactor: + # Create Adafactor optimizer + optim = optax.adafactor( + learning_rate=linear_decay_lr_schedule_fn, + dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32, + weight_decay_rate=training_args.weight_decay, + weight_decay_mask=decay_mask_fn, + ) + else: + # Create AdamW optimizer + optim = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn, + ) + + # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. 
gradient accumulation steps > 1) + if training_args.multisteps and gradient_accumulation_steps > 1: + optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False) + else: + num_epochs = 0 + total_train_steps = 0 + num_train_samples = 0 + optim = None + + # Setup train state + state = MixedPrecisionTrainState.create( + step=data_args.skip_steps, + apply_fn=model.__call__, + get_attention_mask_fn=model._get_feature_vector_attention_mask, + params=model.params, + tx=optim, + to_dtype=to_dtype, + dropout_rng=dropout_rng, + max_grad_norm=training_args.max_grad_norm, + ) + + # Replicate the train state on each device + state = state.replicate() + blank_id = model.config.pad_token_id + + # Define gradient update step fn + def train_step(state, batch): + # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch + dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + + def compute_loss(params, minibatch): + labels = minibatch.pop("labels") + logits = state.apply_fn( + **minibatch, + params=params, + dropout_rng=dropout_rng, + freeze_feature_encoder=model_args.freeze_feature_encoder, + train=True, + )[0] + logits_mask = state.get_attention_mask_fn(logits.shape[1], batch["attention_mask"]) + loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean") + + return loss + + grad_fn = jax.value_and_grad(compute_loss) + + if gradient_accumulation_steps == 1 or training_args.multisteps: + loss, grad = grad_fn(to_dtype(state.params), batch) + + # Custom gradient accumulation + else: + # add a first dimension over gradient_accumulation_steps for minibatch slices + batch = jax.tree_util.tree_map( + lambda x: x.reshape( + gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::] + ), + batch, + ) + + def accum_minibatch_step(accum_grad, minibatch): + # compute loss, num labels and grad over minibatch and accumulate + loss, grad = grad_fn(to_dtype(state.params), minibatch) + return jax.tree_util.tree_map(jnp.add, accum_grad, grad), loss + + # create an initial state for accumulating losses, num labels and gradients + init_grad = jax.tree_util.tree_map(jnp.zeros_like, to_dtype(state.params)) + # loop accum minibatch step over the number of gradient accumulation steps + grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch) + + # update state + new_state = state.apply_gradients( + grads=grad, + dropout_rng=new_dropout_rng, + to_dtype=to_dtype, + ) + + # compute gradient norms over all layers and globally for detailed monitoring + layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad) + logs = { + "layer_grad_norm": layer_grad_norm, + "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)), + } + + # compute parameter norms over all layers and globally for detailed monitoring + layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params) + logs["layer_param_norm"] = layer_param_norm + logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm)) + + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + metrics.update(logs) + + metrics = jax.lax.pmean(metrics, axis_name="batch") + # metrics = to_fp32(metrics) + + return new_state, metrics + + # Define eval fn + def eval_step(params, batch): + labels = batch.pop("labels") + logits = model(**batch, params=params, train=False)[0] + + logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], 
batch["attention_mask"]) + loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean") + + pred_ids = jnp.argmax(logits, axis=-1) + + # summarize metrics + metrics = {"loss": loss} + metrics = jax.lax.pmean(metrics, axis_name="batch") + # metrics = to_fp32(metrics) + return metrics, pred_ids + + # Create parallel version of the train and eval step + if training_args.do_train: + p_train_step = jax.jit(jax.pmap(train_step, "batch", donate_argnums=(0,))) + + if training_args.do_eval: + p_eval_step = jax.jit(jax.pmap(eval_step, "batch")) + + def run_evaluation(step): + if training_args.do_eval: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_preds = [] + eval_labels = [] + + # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length + eval_samples_idx = get_grouped_indices(vectorized_datasets[data_args.eval_split_name], eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) + + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): + samples = [vectorized_datasets[data_args.eval_split_name][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + labels = batch["labels"] + + metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size) + eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))) + eval_metrics.append(metrics) + + eval_labels.extend(labels) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + eval_metrics = to_fp32(eval_metrics) + + # always run compute metrics + error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels) + eval_metrics.update(error_rate_metric) + error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()]) + + # Print metrics and update progress bar + desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})" + epochs.write(desc) + epochs.desc = desc + + # Save metrics + write_wandb_log(eval_metrics, step, prefix="eval") + write_wandb_pred(pred_str, label_str, step) + # if has_tensorboard and jax.process_index() == 0: + # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str) + + def save_checkpoint(step): + # save and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False) + + skip_epochs = data_args.skip_steps // (num_train_samples // batch_size_per_update) + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_train_samples}") + logger.info(f" Num Epochs = {num_epochs}") + logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") + logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}") + logger.info(f" Total train batch size (w. 
parallel & distributed) = {batch_size_per_update}") + logger.info(f" Total optimization steps = {total_train_steps}") + logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}") + logger.info(f" Use scan: {config.use_scan}") + logger.info(f" Fuse matmuls: {config.fuse_matmuls}") + logger.info(f" Skipping: {data_args.skip_steps} steps ({skip_epochs} epochs)") + + train_time = cur_step = 0 + epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) + for epoch in epochs: + if training_args.do_train: + # ======================== Training ================================ + train_start = time.time() + + if epoch < skip_epochs: + logger.info(f"Skipping epoch {epoch + 1}") + # Create sampling rng + rng, input_rng = jax.random.split(rng) + continue + + # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length + train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng) + train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update) + + if data_args.skip_steps > cur_step: + logger.info(f"Skipping {data_args.skip_steps - (epoch * (num_train_samples // batch_size_per_update))} steps...") + # Gather the indices for creating the batch and do a training step + for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1): + cur_step = epoch * (num_train_samples // batch_size_per_update) + step + if cur_step <= data_args.skip_steps: + continue + + samples = [vectorized_datasets[data_args.train_split_name][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + batch = shard(batch.data) + try: + state, train_metric = p_train_step(state, batch) + except TypeError as e: + logger.warning("Encountered following error: \n", e) + + + if cur_step % training_args.logging_steps == 0: + # Save metrics + train_metric = unreplicate(train_metric) + train_time += time.time() - train_start + # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step + write_wandb_log(to_fp32(train_metric), cur_step, prefix=data_args.train_split_name) + # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis) + # if has_tensorboard and jax.process_index() == 0: + # write_train_metric(summary_writer, train_metrics, train_time, cur_step) + + epochs.write( + f"Step... 
({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})" + ) + + if cur_step % total_train_steps == 0: + break + + if training_args.eval_steps and cur_step % training_args.eval_steps == 0: + run_evaluation(cur_step) + p_train_step.clear_cache() + p_eval_step.clear_cache() + + if cur_step % training_args.save_steps == 0: + save_checkpoint(cur_step) + + if training_args.eval_steps == 0 and (epoch + 1) != num_epochs: + # run evaluation at the end of the epoch if eval steps are not specified + run_evaluation(cur_step) + save_checkpoint(cur_step) + + if training_args.do_train: + save_checkpoint(cur_step) + + cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training + + if training_args.do_eval: + run_evaluation(cur_step) + + # TODO: collapse 'do_predict' into the run_evaluation function + if training_args.do_predict: + for split in [data_args.test_split_name]: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_preds = [] + eval_labels = [] + + # Generate eval set by sequentially sampling indices from the test dataset and grouping by length + eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) + + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)): + samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + labels = batch["labels"] + + metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size) + eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))) + eval_metrics.append(metrics) + + eval_labels.extend(labels) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + eval_metrics = to_fp32(eval_metrics) + + # always run compute metrics + error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels) + eval_metrics.update(error_rate_metric) + error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()]) + + # Print metrics and update progress bar + desc = f"Step... 
({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})" + epochs.write(desc) + epochs.desc = desc + + # Save metrics + write_wandb_log(eval_metrics, cur_step, prefix=split) + write_wandb_pred(pred_str, label_str, cur_step, prefix=split) + # if has_tensorboard and jax.process_index() == 0: + # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str) + + +if __name__ == "__main__": + main() diff --git a/wandb/run-20220811_085413-2dwhhb1y/files/config.yaml b/wandb/run-20220811_085413-2dwhhb1y/files/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4f4349bd8a2b51ff48656d0cb152939619529525 --- /dev/null +++ b/wandb/run-20220811_085413-2dwhhb1y/files/config.yaml @@ -0,0 +1,33 @@ +wandb_version: 1 + +_wandb: + desc: null + value: + cli_version: 0.12.9 + code_path: code/run_flax_speech_recognition_ctc.py + framework: huggingface + huggingface_version: 4.21.0 + is_jupyter_run: false + is_kaggle_kernel: false + python_version: 3.8.10 + start_time: 1660208053 + t: + 1: + - 1 + - 2 + - 3 + - 11 + - 12 + 2: + - 1 + - 2 + - 3 + - 11 + - 12 + 3: + - 13 + 4: 3.8.10 + 5: 0.12.9 + 6: 4.21.0 + 8: + - 5 diff --git a/wandb/run-20220811_085413-2dwhhb1y/files/diff.patch b/wandb/run-20220811_085413-2dwhhb1y/files/diff.patch new file mode 100644 index 0000000000000000000000000000000000000000..9c7f970da572ec58b74613cc714e80b3e6257382 --- /dev/null +++ b/wandb/run-20220811_085413-2dwhhb1y/files/diff.patch @@ -0,0 +1,226 @@ +diff --git a/run.recover.sh b/run.recover.sh +index 77ad3fd..632a336 100755 +--- a/run.recover.sh ++++ b/run.recover.sh +@@ -11,9 +11,8 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c + --per_device_train_batch_size="2" \ + --per_device_eval_batch_size="2" \ + --gradient_accumulation_steps="1" \ +- --precision="full_mixed" \ ++ --precision="half_mixed" \ + --matmul_precision="bfloat16" \ +- --multisteps \ + --learning_rate="6.394633237505332e-05" \ + --skip_steps="275000" \ + --warmup_steps="2000" \ +diff --git a/run.sh b/run.sh +index 8758978..6adf9ee 100755 +--- a/run.sh ++++ b/run.sh +@@ -1,3 +1,6 @@ ++# See https://github.com/sanchit-gandhi/seq2seq-speech/issues/23#issuecomment-1122183173: do_lower_case should only be set to True for the tokenizer if the tokenizer has upper case letters in the vocab ++# Let's also not add extra remove_punctuation ++# And limit max duration to 25 seconds + WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_ctc.py \ + --model_name_or_path="facebook/wav2vec2-xls-r-1b" \ + --hub_model_id="NbAiLab/wav2vec2-1b-npsc-nst-tpu" \ +@@ -11,7 +14,7 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c + --precision="full_mixed" \ + --matmul_precision="bfloat16" \ + --multisteps \ +- --learning_rate="1e-4" \ ++ --learning_rate="2e-5" \ + --warmup_steps="2000" \ + --length_column_name="input_length" \ + --evaluation_strategy="steps" \ +@@ -32,7 +35,7 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c + --mask_feature_length="64" \ + --gradient_checkpointing \ + --min_duration_in_seconds="0.5" \ +- --max_duration_in_seconds="30.0" \ ++ --max_duration_in_seconds="25.0" \ + --use_auth_token \ + --seed="42" \ + --group_by_length \ +@@ -40,10 +43,5 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c + --push_to_hub \ + --preprocessing_num_workers="32" \ + --ctc_zero_infinity \ +- --do_lower_case \ + --wandb_project="wav2vec2" \ + 
--wandb_name="wav2vec2-1b-npsc-nst-tpu" \ +- --remove_punctuation +- +- +-# --fp16 +diff --git a/run_flax_speech_recognition_ctc.py b/run_flax_speech_recognition_ctc.py +index a330879..688d006 100644 +--- a/run_flax_speech_recognition_ctc.py ++++ b/run_flax_speech_recognition_ctc.py +@@ -415,12 +415,12 @@ class MixedPrecisionTrainState(struct.PyTreeNode): + ) + + @classmethod +- def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs): ++ def create(cls, *, apply_fn, params, tx, to_dtype, step=0, **kwargs): + """Creates a new instance with `step=0` and initialized `opt_state`.""" + # downcast optimizer state to bf16 if mixed-precision training + opt_state = tx.init(to_dtype(params)) if tx is not None else None + return cls( +- step=0, ++ step=step, + apply_fn=apply_fn, + params=params, + tx=tx, +@@ -1339,6 +1339,7 @@ def main(): + + # Setup train state + state = MixedPrecisionTrainState.create( ++ step=data_args.skip_steps, + apply_fn=model.__call__, + get_attention_mask_fn=model._get_feature_vector_attention_mask, + params=model.params, +@@ -1441,10 +1442,10 @@ def main(): + + # Create parallel version of the train and eval step + if training_args.do_train: +- p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) ++ p_train_step = jax.jit(jax.pmap(train_step, "batch", donate_argnums=(0,))) + + if training_args.do_eval: +- p_eval_step = jax.pmap(eval_step, "batch") ++ p_eval_step = jax.jit(jax.pmap(eval_step, "batch")) + + def run_evaluation(step): + if training_args.do_eval: +@@ -1520,11 +1521,10 @@ def main(): + + if epoch < skip_epochs: + logger.info(f"Skipping epoch {epoch + 1}") ++ # Create sampling rng ++ rng, input_rng = jax.random.split(rng) + continue + +- # Create sampling rng +- rng, input_rng = jax.random.split(rng) +- + # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length + train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng) + train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update) +@@ -1565,6 +1565,8 @@ def main(): + + if training_args.eval_steps and cur_step % training_args.eval_steps == 0: + run_evaluation(cur_step) ++ p_train_step.clear_cache() ++ p_eval_step.clear_cache() + + if cur_step % training_args.save_steps == 0: + save_checkpoint(cur_step) +diff --git a/special_tokens_map.json b/special_tokens_map.json +index 218961f..5807947 100644 +--- a/special_tokens_map.json ++++ b/special_tokens_map.json +@@ -399,6 +399,76 @@ + "rstrip": false, + "single_word": false + }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ 
"lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, + { + "content": "", + "lstrip": false, +diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log +index 23926ef..3056b1b 120000 +--- a/wandb/debug-internal.log ++++ b/wandb/debug-internal.log +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4/logs/debug-internal.log +\ No newline at end of file ++run-20220811_085413-2dwhhb1y/logs/debug-internal.log +\ No newline at end of file +diff --git a/wandb/debug.log b/wandb/debug.log +index 279853d..c95117f 120000 +--- a/wandb/debug.log ++++ b/wandb/debug.log +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4/logs/debug.log +\ No newline at end of file ++run-20220811_085413-2dwhhb1y/logs/debug.log +\ No newline at end of file +diff --git a/wandb/latest-run b/wandb/latest-run +index f069a7a..3b0aef7 120000 +--- a/wandb/latest-run ++++ b/wandb/latest-run +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4 +\ No newline at end of file ++run-20220811_085413-2dwhhb1y +\ No newline at end of file diff --git a/wandb/run-20220811_085413-2dwhhb1y/files/output.log b/wandb/run-20220811_085413-2dwhhb1y/files/output.log new file mode 100644 index 0000000000000000000000000000000000000000..37836e625018c63b5576226a66a4f6eaf54517fb --- /dev/null +++ b/wandb/run-20220811_085413-2dwhhb1y/files/output.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2de00b0f75559621b80c8337af18e5f3c5273eb74b22ef2ea73d6d99365a6034 +size 200283 diff --git a/wandb/run-20220811_085413-2dwhhb1y/files/requirements.txt b/wandb/run-20220811_085413-2dwhhb1y/files/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e0273eb6554b8538eecc3cb9f4a47c988bd3d0dd --- /dev/null +++ b/wandb/run-20220811_085413-2dwhhb1y/files/requirements.txt @@ -0,0 +1,158 @@ +absl-py==1.0.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +appdirs==1.4.4 +astunparse==1.6.3 +async-timeout==4.0.2 +attrs==21.4.0 +audioread==2.1.9 +backcall==0.2.0 +cachetools==4.2.4 +certifi==2021.10.8 +cffi==1.15.1 +charset-normalizer==2.0.10 +chex==0.1.3 +click==8.0.3 +cloud-tpu-client==0.10 +cloud-tpu-profiler==2.4.0 +clu==0.0.6 +colorama==0.4.5 +commonmark==0.9.1 +configparser==5.2.0 +contextlib2==21.6.0 +cycler==0.11.0 +datasets==2.4.0 +decorator==5.1.0 +dill==0.3.4 +dm-tree==0.1.6 +docker-pycreds==0.4.0 +etils==0.6.0 +exceptiongroup==1.0.0rc8 +filelock==3.4.2 +flatbuffers==2.0 +flax==0.5.3 +fonttools==4.28.5 +frozenlist==1.2.0 +fsspec==2021.11.1 +future==0.18.2 +gast==0.4.0 +gitdb==4.0.9 +gitpython==3.1.26 +google-api-core==1.31.5 +google-api-python-client==1.8.0 +google-auth-httplib2==0.1.0 +google-auth-oauthlib==0.4.6 +google-auth==2.3.3 +google-pasta==0.2.0 +googleapis-common-protos==1.54.0 +grpcio==1.43.0 +h5py==3.6.0 +httplib2==0.20.2 +huggingface-hub==0.2.1 +hypothesis==6.53.0 +idna==3.3 +importlib-metadata==4.10.0 +importlib-resources==5.4.0 +ipython==7.31.0 +jax==0.3.15 +jaxlib==0.3.15 +jedi==0.18.1 +jiwer==2.3.0 +joblib==1.1.0 +keras-preprocessing==1.1.2 +keras==2.7.0 +kiwisolver==1.3.2 +libclang==12.0.0 +librosa==0.9.2 +libtpu-nightly==0.1.dev20220722 +llvmlite==0.39.0 +markdown==3.3.6 +matplotlib-inline==0.1.3 +matplotlib==3.5.1 +ml-collections==0.1.0 +msgpack==1.0.3 +multidict==5.2.0 +multiprocess==0.70.12.2 +numba==0.56.0 +numpy==1.22.0 +oauth2client==4.1.3 +oauthlib==3.1.1 +opt-einsum==3.3.0 +optax==0.1.3 +packaging==21.3 +pandas==1.3.5 +parso==0.8.3 +pathtools==0.1.2 +pexpect==4.8.0 
+pickleshare==0.7.5 +pillow==9.0.0 +pip==22.2.2 +pkg-resources==0.0.0 +pooch==1.6.0 +promise==2.3 +prompt-toolkit==3.0.24 +protobuf==3.19.1 +psutil==5.9.0 +ptyprocess==0.7.0 +pyarrow==6.0.1 +pyasn1-modules==0.2.8 +pyasn1==0.4.8 +pycparser==2.21 +pyctcdecode==0.4.0 +pygments==2.11.1 +pygtrie==2.5.0 +pyparsing==3.0.6 +python-dateutil==2.8.2 +python-levenshtein==0.12.2 +pytz==2021.3 +pyyaml==6.0 +regex==2021.11.10 +requests-oauthlib==1.3.0 +requests==2.27.0 +resampy==0.3.1 +responses==0.18.0 +rich==11.2.0 +rsa==4.8 +sacremoses==0.0.46 +scikit-learn==1.1.1 +scipy==1.7.3 +sentry-sdk==1.5.2 +setuptools==44.0.0 +shortuuid==1.0.8 +six==1.16.0 +smmap==5.0.0 +sortedcontainers==2.4.0 +soundfile==0.10.3.post1 +sox==1.4.1 +subprocess32==3.5.4 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.0 +tensorboard==2.7.0 +tensorflow-cpu==2.7.0 +tensorflow-datasets==4.4.0 +tensorflow-estimator==2.7.0 +tensorflow-io-gcs-filesystem==0.23.1 +tensorflow-metadata==1.5.0 +tensorflow==2.7.0 +tensorstore==0.1.21 +termcolor==1.1.0 +threadpoolctl==3.1.0 +tokenizers==0.11.2 +toolz==0.11.2 +torch==1.12.0 +torchaudio==0.12.0+cpu +tqdm==4.62.3 +traitlets==5.1.1 +transformers==4.21.0 +typing-extensions==4.3.0 +uritemplate==3.0.1 +urllib3==1.26.7 +wandb==0.12.9 +wcwidth==0.2.5 +werkzeug==2.0.2 +wheel==0.37.1 +wrapt==1.13.3 +xxhash==2.0.2 +yarl==1.7.2 +yaspin==2.1.0 +zipp==3.7.0 \ No newline at end of file diff --git a/wandb/run-20220811_085413-2dwhhb1y/files/wandb-metadata.json b/wandb/run-20220811_085413-2dwhhb1y/files/wandb-metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9e9755cf0ad16492c0b9bd724031e2d773711c78 --- /dev/null +++ b/wandb/run-20220811_085413-2dwhhb1y/files/wandb-metadata.json @@ -0,0 +1,67 @@ +{ + "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29", + "python": "3.8.10", + "heartbeatAt": "2022-08-11T08:54:16.572334", + "startedAt": "2022-08-11T08:54:13.207979", + "docker": null, + "cpu_count": 96, + "cuda": null, + "args": [ + "--model_name_or_path=facebook/wav2vec2-xls-r-1b", + "--hub_model_id=NbAiLab/wav2vec2-1b-npsc-nst-tpu", + "--tokenizer_name=./", + "--output_dir=./", + "--overwrite_output_dir", + "--num_train_epochs=40", + "--per_device_train_batch_size=2", + "--per_device_eval_batch_size=2", + "--gradient_accumulation_steps=1", + "--precision=full_mixed", + "--matmul_precision=bfloat16", + "--multisteps", + "--learning_rate=2e-5", + "--warmup_steps=2000", + "--length_column_name=input_length", + "--evaluation_strategy=steps", + "--text_column_name=text", + "--save_steps=5000", + "--eval_steps=5000", + "--logging_steps=100", + "--layerdrop=0.041", + "--attention_dropout=0.094", + "--activation_dropout=0.055", + "--hidden_dropout=0.047", + "--save_total_limit=5", + "--freeze_feature_encoder", + "--feat_proj_dropout=0.04", + "--mask_time_prob=0.082", + "--mask_time_length=10", + "--mask_feature_prob=0.25", + "--mask_feature_length=64", + "--gradient_checkpointing", + "--min_duration_in_seconds=0.5", + "--max_duration_in_seconds=25.0", + "--use_auth_token", + "--seed=42", + "--group_by_length", + "--do_train", + "--do_eval", + "--push_to_hub", + "--preprocessing_num_workers=32", + "--ctc_zero_infinity", + "--wandb_project=wav2vec2", + "--wandb_name=wav2vec2-1b-npsc-nst-tpu" + ], + "state": "running", + "program": "run_flax_speech_recognition_ctc.py", + "codePath": "run_flax_speech_recognition_ctc.py", + "git": { + "remote": "https://huggingface.co/NbAiLab/wav2vec2-1b-npsc-nst-tpu", + "commit": "f624ac4bfedfbf56891676d7c5f2e37b4c8e0745" + }, + "email": 
"versae@gmail.com", + "root": "/data/wav2vec2-1b-npsc-nst-tpu", + "host": "t1v-n-eedfb410-w-0", + "username": "javierr", + "executable": "/data/flax/bin/python" +} diff --git a/wandb/run-20220811_085413-2dwhhb1y/files/wandb-summary.json b/wandb/run-20220811_085413-2dwhhb1y/files/wandb-summary.json new file mode 100644 index 0000000000000000000000000000000000000000..fba38fd28e3776a9eb2ccdc53ac547b661b10580 --- /dev/null +++ b/wandb/run-20220811_085413-2dwhhb1y/files/wandb-summary.json @@ -0,0 +1 @@ +{"_wandb": {"runtime": 685}} \ No newline at end of file diff --git a/wandb/run-20220811_085413-2dwhhb1y/logs/debug-internal.log b/wandb/run-20220811_085413-2dwhhb1y/logs/debug-internal.log new file mode 100644 index 0000000000000000000000000000000000000000..f8fa34313eeacfdd6027629033b3bc2731607cc9 --- /dev/null +++ b/wandb/run-20220811_085413-2dwhhb1y/logs/debug-internal.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a935ea8dbbd11ad00913a95e7b66e1d2997ef22109708f5da6218cbf5ac5195 +size 37143 diff --git a/wandb/run-20220811_085413-2dwhhb1y/logs/debug.log b/wandb/run-20220811_085413-2dwhhb1y/logs/debug.log new file mode 100644 index 0000000000000000000000000000000000000000..6ee23d2fc6f4d2def46aa126f820bb3539d393f9 --- /dev/null +++ b/wandb/run-20220811_085413-2dwhhb1y/logs/debug.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ae3bdb6afe081d53ccecbb8a8d4d88c8cbbc86e4c0d035151f45c5ac18d6d01 +size 5815 diff --git a/wandb/run-20220811_085413-2dwhhb1y/run-2dwhhb1y.wandb b/wandb/run-20220811_085413-2dwhhb1y/run-2dwhhb1y.wandb new file mode 100644 index 0000000000000000000000000000000000000000..a7eee06aba40381444978343265b11da23eec4ce --- /dev/null +++ b/wandb/run-20220811_085413-2dwhhb1y/run-2dwhhb1y.wandb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4abe7d114964a72aec9188cda305ac78a325cc9bc623a028d570032fee762533 +size 211934 diff --git a/wandb/run-20220811_094956-332xvl6v/files/code/run_flax_speech_recognition_ctc.py b/wandb/run-20220811_094956-332xvl6v/files/code/run_flax_speech_recognition_ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..4a5d9404604defc9d7cdab04832844ec56ce7978 --- /dev/null +++ b/wandb/run-20220811_094956-332xvl6v/files/code/run_flax_speech_recognition_ctc.py @@ -0,0 +1,1631 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2022 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition. +""" +# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 
+ +import logging +import math +import os +import re +import sys +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + +import datasets +import numpy as np +from datasets import DatasetDict, load_dataset, load_metric +from tqdm import tqdm + +import flax +import jax +import jax.numpy as jnp +import optax +import transformers +import wandb as wandb +from flax import core, jax_utils, struct, traverse_util +from flax.jax_utils import unreplicate, pad_shard_unpad +from flax.training.common_utils import get_metrics, shard, shard_prng_key +from huggingface_hub import Repository +from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC +from optax._src import linear_algebra +from transformers import ( + AutoFeatureExtractor, + AutoProcessor, + AutoTokenizer, + HfArgumentParser, + TrainingArguments, + is_tensorboard_available, + set_seed, +) +from transformers.file_utils import get_full_repo_name +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.17.0.dev0") + +require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") + +logger = logging.getLogger(__name__) + + +@flax.struct.dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + feature_extractor_name: Optional[str] = field( + default=None, metadata={"help": "feature extractor name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + freeze_feature_encoder: bool = field( + default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."} + ) + attention_dropout: float = field( + default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."} + ) + activation_dropout: float = field( + default=0.1, + metadata={ + "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler." + }, + ) + hidden_dropout: float = field( + default=0.1, + metadata={ + "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler." + }, + ) + feat_proj_dropout: float = field( + default=0.0, + metadata={ + "help": "The feat proj dropout probability for feature encoder representations." 
+ }, + ) + final_dropout: float = field( + default=0.0, + metadata={"help": "The dropout probability for the final projection layer."}, + ) + mask_time_prob: float = field( + default=0.1, + metadata={ + "help": "The spec aug dropout probability for feature encoder representations." + }, + ) + mask_time_length: int = field( + default=10, + metadata={"help": "Length of vector span to mask along the time axis."}, + ) + mask_feature_prob: float = field( + default=0.0, + metadata={ + "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector" + "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis." + }, + ) + mask_feature_length: int = field( + default=10, + metadata={"help": "Length of vector span to mask along the feature axis."}, + ) + layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."}) + ctc_loss_reduction: Optional[str] = field( + default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."} + ) + ctc_zero_infinity: Optional[bool] = field( + default=False, metadata={"help": "If True, will try to avoid the CTC loss going to infinity."} + ) + + +@flax.struct.dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: str = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + text_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the full texts (for transcription)."}, + ) + dataset_cache_dir: Optional[str] = field( + default=None, metadata={"help": "Path to cache directory for saving and loading datasets"} + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_test_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of test examples to this " + "value if set." + }, + ) + audio_column_name: str = field( + default="audio", + metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"}, + ) + text_column_name: str = field( + default="text", + metadata={"help": "The name of the dataset column containing the text data. 
Defaults to 'text'"}, + ) + max_duration_in_seconds: float = field( + default=20.0, + metadata={ + "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to `max_duration_in_seconds`" + }, + ) + min_duration_in_seconds: float = field( + default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"} + ) + max_label_length: Optional[int] = field( + default=512, + metadata={ + "help": "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be filtered." + }, + ) + min_label_length: Optional[int] = field( + default=2, + metadata={ + "help": "The minimum total sequence length for target text after tokenization. Sequences shorter " + "than this will be filtered." + }, + ) + pad_input_to_multiple_of: Optional[int] = field( + default=32000, + metadata={ + "help": "If set will pad the input sequence to a multiple of the provided value. " + "This is important to avoid triggering recompilations on TPU." + }, + ) + pad_target_to_multiple_of: Optional[int] = field( + default=None, + metadata={ + "help": "If set will pad the target sequence to a multiple of the provided value. " + "This is important to avoid triggering recompilations on TPU." + }, + ) + preprocessing_only: bool = field( + default=False, + metadata={ + "help": "Whether to only do data preprocessing and skip training. " + "This is especially useful when data preprocessing errors out in distributed training due to timeout. " + "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` " + "so that the cached datasets can consequently be loaded in distributed training" + }, + ) + train_split_name: str = field( + default="train", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + eval_split_name: str = field( + default="validation", + metadata={ + "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'" + }, + ) + do_lower_case: bool = field( + default=True, + metadata={"help": "Whether the target text should be lower cased."}, + ) + wandb_project: str = field( + default="flax-speech-recognition-ctc", + metadata={"help": "The name of the wandb project."}, + ) + wandb_name: str = field( + default=None, + metadata={"help": "The name of the wandb run."}, + ) + wandb_job_type: str = field( + default="CTC", + metadata={"help": "The name of the wandb job type."}, + ) + test_split_name: str = field( + default="test", + metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"}, + ) + remove_punctuation: bool = field( + default=False, metadata={"help": "Whether or not to remove punctuation during training."} + ) + skip_steps: Optional[int] = field( + default=0, + metadata={ + "help": "Skip this number of steps. Useful to continue training." + }, + ) + + +# @flax.struct.dataclass +@dataclass +class FlaxTrainingArguments(TrainingArguments): + precision: str = field( + default="full", + metadata={ + "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision. " + "**Note that this only specifies the dtype of the computation and optimizer state. 
It does not influence the dtype of model parameters.**" + }, + ) + matmul_precision: str = field( + default="default", + metadata={ + "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. " + "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). " + "This configuration option does not change the behaviours of such calls with explicit precision arguments; " + "it only changes the behaviors of calls with no such argument provided. " + "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`." + }, + ) + multisteps: bool = field( + default=False, + metadata={ + "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, " + "a custom gradient accumulation implementation will be employed." + }, + ) + + +def to_fp32(t): + return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t) + + +def to_bf16(t): + return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t) + + +class MixedPrecisionTrainState(struct.PyTreeNode): + """Train state for use with a single Optax optimizer. + Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py + + Synopsis:: + + state = TrainState.create( + apply_fn=model.apply, + params=variables['params'], + tx=tx) + grad_fn = jax.grad(make_loss_fn(state.apply_fn)) + for batch in data: + grads = grad_fn(state.params, batch) + state = state.apply_gradients(grads=grads) + + Args: + step: Counter starts at 0 and is incremented by every call to + `.apply_gradients()`. + apply_fn: Usually set to `model.apply()`. Kept in this dataclass for + convenience to have a shorter params list for the `train_step()` function + in your training loop. + params: The parameters to be updated by `tx` and used by `apply_fn`. + tx: An Optax gradient transformation. + opt_state: The state for `tx`. + dropout_rng: PRNG key for stochastic operations. + bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. + """ + + step: int + apply_fn: Callable = struct.field(pytree_node=False) + get_attention_mask_fn: Callable = struct.field(pytree_node=False) + params: core.FrozenDict[str, Any] + tx: optax.GradientTransformation = struct.field(pytree_node=False) + opt_state: optax.OptState + dropout_rng: jnp.ndarray + max_grad_norm: Optional[float] = 1.0 + + def apply_gradients(self, *, grads, to_dtype, **kwargs): + """Updates `step`, `params`, `opt_state` and `**kwargs` in return value. + + Note that internally this function calls `.tx.update()` followed by a call + to `optax.apply_updates()` to update `params` and `opt_state`. + + Args: + grads: Gradients that have the same pytree structure as `.params`. + **kwargs: Additional dataclass attributes that should be `.replace()`-ed. + + Returns: + An updated instance of `self` with `step` incremented by one, `params` + and `opt_state` updated by applying `grads`, and additional attributes + replaced as specified by `kwargs`. 
+ """ + + # clip gradients by global l2 norm + casted_max_grad_norm = to_dtype(self.max_grad_norm) + g_norm = linear_algebra.global_norm(grads) + g_norm = jnp.maximum(casted_max_grad_norm, g_norm) + grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads) + + # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training + # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is) + updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params) + + new_params = optax.apply_updates(self.params, updates) + return self.replace( + step=self.step + 1, + params=new_params, + opt_state=to_dtype(new_opt_state), + **kwargs, + ) + + @classmethod + def create(cls, *, apply_fn, params, tx, to_dtype, step=0, **kwargs): + """Creates a new instance with `step=0` and initialized `opt_state`.""" + # downcast optimizer state to bf16 if mixed-precision training + opt_state = tx.init(to_dtype(params)) if tx is not None else None + return cls( + step=step, + apply_fn=apply_fn, + params=params, + tx=tx, + opt_state=opt_state, + **kwargs, + ) + + def replicate(self): + return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) + + +@flax.struct.dataclass +class FlaxDataCollatorSpeechSeq2SeqWithPadding: + """ + Data collator that will dynamically pad the inputs received. + Args: + processor ([`Wav2Vec2Processor`]) + The processor used for proccessing the data. + decoder_start_token_id (:obj: `int`) + The begin-of-sentence of the decoder. + input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned input sequences (according to the model's padding side and padding index) + among: + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned target sequences (according to the model's padding side and padding index). + See above for details. + max_input_length (:obj:`float`, `optional`): + Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). + pad_input_to_multiple_of (:obj:`int`, `optional`): + If set will pad the input sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + pad_target_to_multiple_of (:obj:`int`, `optional`): + If set will pad the target sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). 
+ """ + + processor: Any + input_padding: Union[bool, str] = "longest" + label_padding: Union[bool, str] = "max_length" + pad_input_to_multiple_of: Optional[int] = None + pad_to_multiple_of_label: Optional[int] = None + max_input_length: Optional[float] = None + max_label_length: Optional[float] = None + + def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]: + # split inputs and labels since they have to be of different lengths and need + # different padding methods + input_features = [{"input_values": feature["input_values"]} for feature in features] + label_features = [{"input_ids": feature["labels"]} for feature in features] + + # reformat list to dict and set to pytorch format + batch = self.processor.feature_extractor.pad( + input_features, + max_length=self.max_input_length, + padding=self.input_padding, + pad_to_multiple_of=self.pad_input_to_multiple_of, + return_tensors="np", + ) + + labels_batch = self.processor.tokenizer.pad( + label_features, + max_length=self.max_label_length, + padding=self.label_padding, + pad_to_multiple_of=self.pad_to_multiple_of_label, + return_tensors="np", + ) + + labels = labels_batch["input_ids"] + labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1)) + labels = labels.filled(fill_value=-100) + + batch["labels"] = labels + + return batch + + +def get_grouped_indices( + dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None +) -> np.array: + """ + Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486) + Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar + lengths. To do this, the indices are: + + - randomly permuted (if a JAX rng is specified) + - grouped in mega-batches of size `mega_batch_mult * batch_size` + - sorted by length in each mega-batch + + The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of + maximum length placed first, so that an OOM happens sooner rather than later. + """ + lengths = dataset["input_length"] + + # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller. + if mega_batch_mult is None: + mega_batch_mult = min(len(lengths) // (batch_size * 4), 50) + # Just in case, for tiny datasets + if mega_batch_mult == 0: + mega_batch_mult = 1 + + # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler. + num_samples = len(lengths) + indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples) + + megabatch_size = mega_batch_mult * batch_size + megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] + megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches] + + # The rest is to get the biggest batch first. 
+ # Since each megabatch is sorted by descending length, the longest element is the first + megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches] + max_idx = np.argmax(megabatch_maximums).item() + # Switch to put the longest batch in first position + # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch) + megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0] + + megabatches = np.array([i for megabatch in megabatches for i in megabatch]) + + return megabatches + + +def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray: + """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by + the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned.""" + num_samples = len(samples_idx) + if drop_last: + samples_to_remove = num_samples % batch_size + if samples_to_remove != 0: + samples_idx = samples_idx[:-samples_to_remove] + sections_split = num_samples // batch_size + samples_idx = samples_idx.reshape((sections_split, batch_size)) + else: + sections_split = math.ceil(num_samples / batch_size) + samples_idx = np.array_split(samples_idx, sections_split) + return samples_idx + + +def write_train_metric(summary_writer, train_metrics, train_time, step): + summary_writer.scalar("train_time", train_time, step) + + train_metrics = get_metrics(train_metrics) + for key, vals in train_metrics.items(): + tag = f"train_{key}" + for i, val in enumerate(vals): + summary_writer.scalar(tag, val, step - len(vals) + i + 1) + + +def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None): + for metric_name, value in eval_metrics.items(): + summary_writer.scalar(f"eval_{metric_name}", value, step) + + if pred_str is not None: + # write output actual predictions for debugging + summary_writer.text("eval_predictions", "\n".join(pred_str), step) + + +def write_wandb_log(metrics, step, prefix=None): + if jax.process_index() == 0: + log_metrics = {} + for k, v in metrics.items(): + if "layer" in k: + log_metrics[f"{k}/"] = v + elif prefix is not None: + log_metrics[f"{prefix}/{k}"] = v + else: + log_metrics[k] = v + wandb.log(log_metrics, step) + + +def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"): + if jax.process_index() == 0: + # convert str data to a wandb compatible format + str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))] + # we'll log the first 50 predictions for each epoch + wandb.log( + { + f"{prefix}/step_{int(step / 1000)}k": wandb.Table( + columns=["label_str", "pred_str"], data=str_data[:num_log] + ) + }, + step, + ) + + +def create_learning_rate_fn( + num_train_steps: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.array]: + """Returns a linear warmup, linear_decay learning rate function.""" + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +def ctc_loss( + logits, + logits_attention_mask, + labels, + blank_id, + loss_reduction="mean", + output_emission_dict=False, + log_epsilon=-100000.0, +): + """Computes CTC loss. 
+ This function performs forward computation over an FSA with `N * 2` states + where `N` is the max number of labels. The states are split into two groups: + Phi states and emission states. a phi-state accepts repetition of + phi (blank)-symbols and transits to emission state when the correct label is + observed. An emission state accepts repetition of the label and transits to + the next phi states at any time (so called epsilon-transition). + Below, `B` denotes the batch size, `T` denotes the time steps in `logits`, + and `N` denotes the time steps in `labels`. + Args: + logits: (B, T, K)-array containing log-probabilities of each class. + logitpaddings: (B, T)-array. Padding indicators for `logits`. + labels: (B, N)-array containing reference integer labels. + labelpaddings: (B, N)-array. Padding indicators for `labels`. Currently, + `labels` must be right-padded, i.e. each row of `labelpaddings` must be + repetition of zeroes, followed by repetition of ones. + blank_id: Id for blank token. + loss_reduction: one of "mean", "sum", "default" + - "none": no reduction is applied. + - "mean": output loss will be divided by target lengths and then the + mean over the batch is taken. + - "sum": output loss are summed over batch + output_emission_dict: whether to output additional information about the emission probs + Returns: + A pair of `(per_seq_loss, aux)`. + per_seq_loss: + (B,)-array containing loss values for each sequence in the batch. + aux: Dictionary containing interim variables used for computing losses. + aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each + phi-state corresponding to the n-th label. + aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each + emission-state corresponding to the n-th label. + aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol + corresponding to each time frame. + aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label + corresponding to each time frame. + """ + # label paddings are indicated by -100 + labelpaddings = labels < 0 + # logit paddings are the inverse of attention_mask + logitpaddings = ~logits_attention_mask + + # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py + batchsize, unused_maxinputlen, num_classes = logits.shape + batchsize_, maxlabellen = labels.shape + + logprobs = jax.nn.log_softmax(logits) + labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32) + + # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1]. 
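+ # Hypothetical example for clarity (not in the original): labels = [[3, 3, 5]] gives repeat = [[1., 0.]]
+ # from the comparison below and [[1., 0., 0.]] after padding; wherever repeat == 1 the emit-to-phi epsilon
+ # transition in the scan further down is suppressed, so a repeated label can only be reached via a blank.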
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32) + repeat = jnp.pad(repeat, ((0, 0), (0, 1))) + + logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1] + logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] + + one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K] + logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot) + logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] + + logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N] + logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0) + logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N] + + def loop_body(prev, x): + prev_phi, prev_emit = prev + # emit-to-phi epsilon transition, except if the next label is repetition + prev_phi_orig = prev_phi + prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat)) + + logprob_emit, logprob_phi, pad = x + + # phi-to-emit transition + next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit) + # self-loop transition + next_phi = prev_phi + logprob_phi + # emit-to-phi blank transition only when the next label is repetition + next_phi = next_phi.at[:, 1:].set( + jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)) + ) + + pad = pad.reshape((batchsize, 1)) + next_emit = pad * prev_emit + (1.0 - pad) * next_emit + next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi + + return (next_phi, next_emit), (next_phi, next_emit) + + xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0))) + _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs) + + # last row needs to be updated with the last epsilon transition + logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1])) + logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last) + + # extract per_seq_loss + one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1] + per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot) + + if loss_reduction == "mean": + target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1) + loss = (per_seq_loss / target_lengths).mean() + elif loss_reduction == "sum": + loss = per_seq_loss.sum() + else: + loss = per_seq_loss + + if not output_emission_dict: + return loss + + return loss, { + "logalpha_phi": logalpha_phi, + "logalpha_emit": logalpha_emit, + "logprobs_phi": logprobs_phi, + "logprobs_emit": logprobs_emit, + } + + +def make_dataset(data_args, seed=42): + # Pre-processing dataset + import re + + def map_nst(entry): + text = entry["text"].lower() + text = text.replace("(...vær stille under dette opptaket...)", "") + text = re.sub('[áàâ]', 'a', text) + text = re.sub('[ä]', 'æ', text) + text = re.sub('[éèëê]', 'e', text) + text = re.sub('[íìïî]', 'i', text) + text = re.sub('[óòöô]', 'o', text) + text = re.sub('[ö]', 'ø', text) + text = re.sub('[ç]', 'c', text) + text = re.sub('[úùüû]', 'u', text) + # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text) + text = re.sub('\s+', ' ', text) + return {"text": text} + + def filter_nst(entry): + if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)): + return False # Too short + if re.match(entry["type"], "pIW|CA"): + return False # Spelling out words + return True + + def filter_npsc(entry): + # False if there are digits in the 
text + if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)): + return False # Too short + if re.search("\d", entry["text"]): + return False + return True + + def map_npsc(entry): + batch = {"text": entry["text"].lower()} + batch["text"] = re.sub('[áàâ]', 'a', batch["text"]) + batch["text"] = re.sub('[ä]', 'æ', batch["text"]) + batch["text"] = re.sub('[éèëê]', 'e', batch["text"]) + batch["text"] = re.sub('[íìïî]', 'i', batch["text"]) + batch["text"] = re.sub('[óòöô]', 'o', batch["text"]) + batch["text"] = re.sub('[ö]', 'ø', batch["text"]) + batch["text"] = re.sub('[ç]', 'c', batch["text"]) + batch["text"] = re.sub('[úùüû]', 'u', batch["text"]) + batch["text"] = re.sub('\s', ' ', batch["text"]) + batch["text"] = re.sub('', 'eee', batch["text"]) + batch["text"] = re.sub('', 'qqq', batch["text"]) + batch["text"] = re.sub('', 'mmm', batch["text"]) + batch["text"] = re.sub('', 'xxx', batch["text"]) + # batch["text"] = re.sub('', '?', batch["text"]) + if "<" in batch["text"]: + raise ValueError(batch["text"]) + return batch + + nst = datasets.load_dataset("NbAiLab/NST", "no-close") + npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3") + # TODO NST_hesitate + + split = len(npsc[data_args.train_split_name]) / (len(npsc[data_args.train_split_name]) + len(npsc[data_args.eval_split_name])) # Use same train/val ratio as NPSC + nst_train = nst[data_args.train_split_name].train_test_split(train_size=split, seed=seed) + nst[data_args.train_split_name] = nst_train["train"] + nst[data_args.eval_split_name] = nst_train["test"] + + nst = nst.filter(filter_nst).map( + map_nst, + num_proc=data_args.preprocessing_num_workers, + desc="filtering NST", + ).shuffle(seed=seed) + npsc = npsc.filter(filter_npsc).map( + map_npsc, + num_proc=data_args.preprocessing_num_workers, + desc="filtering NPSC", + ).shuffle(seed=seed) + + npsc_base = npsc.remove_columns([col for col in npsc[data_args.train_split_name].column_names if col not in ["text", "audio"]]) + nst_base = nst.remove_columns([col for col in nst[data_args.train_split_name].column_names if col not in ["text", "audio"]]) + + combined = {} + for split in data_args.train_split_name, data_args.eval_split_name, data_args.test_split_name: + probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples + probs = (probs / probs.sum()).tolist() + comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed) + combined[split] = comb + + return datasets.DatasetDict(**combined) + +def main(): + # 1. Parse input arguments + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # 2. Setup logging + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + # Set the verbosity to info of the Transformers logger. + # We only want one process per machine to log things on the screen. + logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) + if jax.process_index() == 0: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # Set up wandb run + if jax.process_index() == 0: + wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type) + + logger.info("Training/evaluation parameters %s", training_args) + + # Set the default TPU matmul precision and display the number of devices + jax.config.update("jax_default_matmul_precision", training_args.matmul_precision) + logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}") + + # 4. Load dataset + + set_seed(training_args.seed) + raw_datasets = make_dataset(data_args, seed=training_args.seed) + + # raw_datasets = DatasetDict() + + # if training_args.do_train: + # raw_datasets[data_args.train_split_name] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=data_args.train_split_name, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + # if training_args.do_eval: + # raw_datasets[data_args.eval_split_name] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=data_args.eval_split_name, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + # if training_args.do_predict: + # test_split = data_args.test_split_name.split("+") + # for split in test_split: + # raw_datasets[split] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=split, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + if not training_args.do_train and not training_args.do_eval and not training_args.do_predict: + raise ValueError( + "Cannot not train, not do evaluation and not do prediction. At least one of " + "training, evaluation or prediction has to be done." + ) + + # if not training, there is no need to run multiple epochs + if not training_args.do_train: + training_args.num_train_epochs = 1 + + if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--audio_column_name` to the correct audio column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--text_column_name` to the correct text column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + # 5. 
Load pretrained model, tokenizer, and feature extractor
+ #
+ # Distributed training:
+ # The .from_pretrained methods guarantee that only one local process can concurrently
+ config = Wav2Vec2Config.from_pretrained(
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
+ cache_dir=model_args.cache_dir,
+ revision=model_args.model_revision,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
+ cache_dir=model_args.cache_dir,
+ revision=model_args.model_revision,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
+ cache_dir=model_args.cache_dir,
+ revision=model_args.model_revision,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ # update config according to training args, model args, and tokenizer attributes
+ config.update(
+ {
+ "feat_proj_dropout": model_args.feat_proj_dropout,
+ "attention_dropout": model_args.attention_dropout,
+ "hidden_dropout": model_args.hidden_dropout,
+ "final_dropout": model_args.final_dropout,
+ "mask_time_prob": model_args.mask_time_prob,
+ "mask_time_length": model_args.mask_time_length,
+ "mask_feature_prob": model_args.mask_feature_prob,
+ "mask_feature_length": model_args.mask_feature_length,
+ "gradient_checkpointing": training_args.gradient_checkpointing,
+ "layerdrop": model_args.layerdrop,
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
+ "ctc_zero_infinity": model_args.ctc_zero_infinity,
+ "pad_token_id": tokenizer.pad_token_id,
+ "vocab_size": tokenizer.vocab_size, # len(tokenizer),
+ "activation_dropout": model_args.activation_dropout,
+ }
+ )
+
+ if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr":
+ raise ValueError(
+ "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to "
+ "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus, "
+ "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is most likely "
+ "built on a lowercased corpus. In this case, set `tokenizer.do_lower_case` to `False`."
+ )
+
+ if training_args.precision == "full_mixed":
+ dtype = jnp.bfloat16
+ training_args.mixed_precision = True
+ elif training_args.precision == "half_mixed":
+ dtype = jnp.bfloat16
+ training_args.mixed_precision = False
+ else:
+ dtype = jnp.float32
+ training_args.mixed_precision = False
+
+ try:
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
+ model_args.model_name_or_path,
+ config=config,
+ dtype=dtype,
+ cache_dir=model_args.cache_dir,
+ revision=model_args.model_revision,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ except:
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
+ model_args.model_name_or_path,
+ config=config,
+ dtype=dtype,
+ cache_dir=model_args.cache_dir,
+ revision=model_args.model_revision,
+ use_auth_token=True if model_args.use_auth_token else None,
+ from_pt=True,
+ )
+
+ # 6. Resample speech dataset ALWAYS
+ raw_datasets = raw_datasets.cast_column(
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
+ )
+
+ # 7. Preprocessing the datasets.
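# --- Editorial aside (not part of the original patch; refers to the precision branch above) ---
# The --precision flag only decides (a) the computation dtype passed to from_pretrained and
# (b) whether the optimizer state is later downcast via the script's to_bf16 helper:
# "full_mixed" does both, "half_mixed" keeps the optimizer state in fp32. A minimal sketch of
# what that downcast does (same helper as defined in this script):
import jax
import jax.numpy as jnp

def to_bf16(t):
    # downcast only float32 leaves of a parameter pytree
    return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)

tree = {"w": jnp.ones((2, 2), dtype=jnp.float32), "step": jnp.array(0, dtype=jnp.int32)}
print(jax.tree_util.tree_map(lambda x: x.dtype, to_bf16(tree)))
# float32 leaves become bfloat16, integer leaves are left untouched
# -----------------------------------------------------------------------------------------------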
+ # We need to read the audio files as arrays and tokenize the targets.
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
+ max_target_length = data_args.max_label_length
+ min_target_length = data_args.min_label_length
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
+ audio_column_name = data_args.audio_column_name
+ num_workers = data_args.preprocessing_num_workers
+ text_column_name = data_args.text_column_name
+ model_input_name = feature_extractor.model_input_names[0]
+ do_lower_case = data_args.do_lower_case
+ dataset_name = data_args.dataset_name
+ chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ")
+ chars_to_ignore_regex = f'[{"".join(chars_to_ignore)}]'
+ # gigaspeech_punctuation = {" ": ",", " ": ".", " ": "?", " ": "!"}
+ # gigaspeech_disfluencies = ["", ""]
+ # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "", "", "", "[laughter-",
+ # "[vocalized-noise]", "_1"]
+ # swb_punctuations = ["{", "}", "[", "]-", "]"]
+ # earnings_disfluencies = ["", "", "", "inaudible", "", ""]
+ ignore_segments = ["ignore_time_segment_in_scoring", "", "", "[noise]", "[laughter]", "[silence]",
+ "[vocalized-noise]", "", "", "", "", "", "", ""]
+
+ if training_args.do_train and data_args.max_train_samples is not None:
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].select(range(data_args.max_train_samples))
+
+ if training_args.do_eval and data_args.max_eval_samples is not None:
+ raw_datasets[data_args.eval_split_name] = raw_datasets[data_args.eval_split_name].select(range(data_args.max_eval_samples))
+
+ if training_args.do_predict and data_args.max_test_samples is not None:
+ raw_datasets[data_args.test_split_name] = raw_datasets[data_args.test_split_name].select(range(data_args.max_test_samples))
+
+ if training_args.do_train and data_args.remove_punctuation:
+
+ def remove_punctuation(batch):
+ batch[text_column_name] = (
+ re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "")
+ )
+ return batch  # return the mutated example so Dataset.map keeps the change
+
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].map(
+ remove_punctuation,
+ num_proc=data_args.preprocessing_num_workers,
+ desc="removing punctuation from train split",
+ )
+
+ # filter data where the targets are ignored in scoring
+ def is_target_labels(input_str):
+ return input_str.lower() not in ignore_segments
+
+ raw_datasets = raw_datasets.filter(
+ is_target_labels,
+ num_proc=num_workers,
+ input_columns=[text_column_name],
+ desc="filtering data where the targets are ignored in scoring",
+ )
+
+ def prepare_dataset(batch):
+ # process audio
+ try:
+ sample = batch[audio_column_name]
+ except ValueError:
+ sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate}
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
+ # process audio length
+ batch[model_input_name] = inputs.input_values[0]
+ batch["input_length"] = len(batch["input_values"])
+
+ # process targets
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
+
+ # if dataset_name == "google/xtreme_s":
+ # # Finally, we tokenize the processed text
+ # batch["labels"] = tokenizer(input_str).input_ids
+ # batch["labels_length"] = len(batch["labels"])
+ # return batch
+
+ # # Common Voice 9
+ # if input_str.startswith('"') and input_str.endswith('"'):
+ # # we can remove
trailing quotation marks as they do not affect the transcription + # input_str = input_str[1:-1] + # # normalize quotation marks + # input_str = re.sub(r'["“”]', '"', input_str) + # # normalize apostrophes + # input_str = re.sub(r"[’']", "'", input_str) + # # normalize hyphens + # input_str = re.sub(r"[—–]", "-", input_str) + # # replace double quotation marks with single + # input_str = input_str.replace('""', '"') + # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str): + # # for CV9, we'll normalize the text to always finish with punctuation + # if input_str[-1] not in [".", "?", "!"]: + # input_str = input_str + "." + + # # TEDLIUM-3 + # # delete the token from the text and replace spaced apostrophes with un-spaced + # input_str = input_str.replace("", "").replace(" '", "'") + + # # GigaSpeech + # for disfluency in gigaspeech_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # convert spelled out punctuation to symbolic form + # for punctuation, replacement in gigaspeech_punctuation.items(): + # input_str = input_str.replace(punctuation, replacement) + # if dataset_name == "speechcolab/gigaspeech" and len(input_str): + # # for GS, we'll normalize the text to always finish with punctuation + # if input_str[-1] not in [".", "?", "!"]: + # input_str = input_str + "." + + # # SWB + # for disfluency in swb_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # remove parenthesised text (test data only) + # input_str = re.sub("[\(].*?[\)]", "", input_str) + # for punctuation in swb_punctuations: + # input_str = input_str.replace(punctuation, "") + # # replace anomalous words with their correct transcriptions + # split_str = input_str.split("/") + # if len(split_str) > 1: + # input_str = " ".join( + # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]]) + + # # Earnings 22 + # for disfluency in earnings_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # replace mal-formatted ellipsis + # input_str = input_str.replace("…", ".") + + # JIWER compliance + # remove multiple spaces + input_str = re.sub(r"\s\s+", " ", input_str) + # strip trailing spaces + input_str = input_str.strip() + + # Finally, we tokenize the processed text + batch["labels"] = tokenizer(input_str).input_ids + batch["labels_length"] = len(batch["labels"]) + return batch + + vectorized_datasets = raw_datasets.map( + prepare_dataset, + remove_columns=next(iter(raw_datasets.values())).column_names, + num_proc=num_workers, + desc="preprocess dataset", + ) + + # filter data with inputs shorter than min_input_length or longer than max_input_length + def is_audio_in_length_range(length): + return length > min_input_length and length < max_input_length + + vectorized_datasets = vectorized_datasets.filter( + is_audio_in_length_range, + num_proc=num_workers, + input_columns=["input_length"], + ) + + # filter data with targets shorter than min_target_length or longer than max_target_length + def is_labels_in_length_range(length): + return length > min_target_length # and length < max_target_length + + vectorized_datasets = vectorized_datasets.filter( + is_labels_in_length_range, + num_proc=num_workers, + input_columns=["labels_length"], + ) + + # for large datasets it is advised to run the preprocessing on a + # single machine first with `args.preprocessing_only` since there will mostly likely + # be a timeout when running the script in distributed mode. 
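# --- Editorial aside (not part of the original patch; refers to the length filters above) ---
# The two filters above work in raw units: audio length is measured in samples
# (seconds * sampling rate), label length in tokens. With a 16 kHz feature extractor and the
# 0.5 s / 25.0 s bounds used in run.sh, the thresholds come out as follows:
sampling_rate = 16_000
min_input_length = int(0.5 * sampling_rate)     # 8_000 samples
max_input_length = int(25.0 * sampling_rate)    # 400_000 samples

def is_audio_in_length_range(length):
    # strict comparison, mirroring the filter in the script
    return length > min_input_length and length < max_input_length

print(is_audio_in_length_range(12 * sampling_rate))   # True  -> a 12 s clip is kept
print(is_audio_in_length_range(30 * sampling_rate))   # False -> a 30 s clip is dropped
# -----------------------------------------------------------------------------------------------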
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the + # cached dataset + if data_args.preprocessing_only: + cache = {k: v.cache_files for k, v in vectorized_datasets.items()} + logger.info(f"Data preprocessing finished. Files cached at {cache}.") + return + + # 8. Load Metrics + wer_metric = load_metric("wer") + cer_metric = load_metric("cer") + + def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]): + padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids)) + + pred_str = tokenizer.batch_decode(pred_ids) + # we do not want to group tokens when computing the metrics + label_str = tokenizer.batch_decode(padded_ids, group_tokens=False) + + wer = wer_metric.compute(predictions=pred_str, references=label_str) + cer = cer_metric.compute(predictions=pred_str, references=label_str) + + return {"wer": wer, "cer": cer}, pred_str, label_str + + # 9. save feature extractor, tokenizer and config + feature_extractor.save_pretrained(training_args.output_dir) + tokenizer.save_pretrained(training_args.output_dir) + config.save_pretrained(training_args.output_dir) + + processor = AutoProcessor.from_pretrained(training_args.output_dir) + + data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding( + processor=processor, + input_padding="longest", + pad_input_to_multiple_of=pad_input_to_multiple_of, + max_label_length=data_args.max_label_length, + ) + + # Enable tensorboard only on the master node + has_tensorboard = is_tensorboard_available() + if has_tensorboard and jax.process_index() == 0: + try: + from flax.metrics.tensorboard import SummaryWriter + + summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) + except ImportError as ie: + has_tensorboard = False + logger.warning( + f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" + ) + else: + logger.warning( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run `pip install tensorboard` to enable." + ) + + # 10. Handle the repository creation + if training_args.push_to_hub: + with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f: + git_lfs_extensions = f.read() + if "*.wandb" not in git_lfs_extensions: + f.write("*.wandb filter=lfs diff=lfs merge=lfs -text") + if training_args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) + else: + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) + + # 11. 
Initialize our training + rng = jax.random.PRNGKey(training_args.seed) + rng, dropout_rng = jax.random.split(rng) + + # Store some constants + max_steps = int(training_args.max_steps) + gradient_accumulation_steps = int(training_args.gradient_accumulation_steps) + train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() + batch_size_per_update = train_batch_size * gradient_accumulation_steps + per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) + eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() + to_dtype = to_bf16 if training_args.mixed_precision else to_fp32 + + if training_args.do_train: + num_train_samples = len(vectorized_datasets[data_args.train_split_name]) + steps_per_epoch = num_train_samples // batch_size_per_update + if max_steps > 0: + num_epochs = -(training_args.max_steps // -steps_per_epoch) + total_train_steps = max_steps + else: + num_epochs = int(training_args.num_train_epochs) + total_train_steps = steps_per_epoch * num_epochs + + # Create learning rate schedule + # Create learning rate schedule + linear_decay_lr_schedule_fn = create_learning_rate_fn( + total_train_steps, + training_args.warmup_steps, + training_args.learning_rate, + ) + + # We use Optax's "masking" functionality to not apply weight decay + # to bias and LayerNorm scale parameters. decay_mask_fn returns a + # mask boolean with the same structure as the parameters. + # The mask is True for parameters that should be decayed. + # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart. + # For FlaxT5, one should correct the layer norm parameter naming + # accordingly - see `run_t5_mlm_flax.py` e.g. + def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + layer_norm_params = [ + (name, "scale") + for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"] + ] + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + + if training_args.adafactor: + # Create Adafactor optimizer + optim = optax.adafactor( + learning_rate=linear_decay_lr_schedule_fn, + dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32, + weight_decay_rate=training_args.weight_decay, + weight_decay_mask=decay_mask_fn, + ) + else: + # Create AdamW optimizer + optim = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn, + ) + + # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. 
gradient accumulation steps > 1) + if training_args.multisteps and gradient_accumulation_steps > 1: + optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False) + else: + num_epochs = 0 + total_train_steps = 0 + num_train_samples = 0 + optim = None + + # Setup train state + state = MixedPrecisionTrainState.create( + step=data_args.skip_steps, + apply_fn=model.__call__, + get_attention_mask_fn=model._get_feature_vector_attention_mask, + params=model.params, + tx=optim, + to_dtype=to_dtype, + dropout_rng=dropout_rng, + max_grad_norm=training_args.max_grad_norm, + ) + + # Replicate the train state on each device + state = state.replicate() + blank_id = model.config.pad_token_id + + # Define gradient update step fn + def train_step(state, batch): + # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch + dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + + def compute_loss(params, minibatch): + labels = minibatch.pop("labels") + logits = state.apply_fn( + **minibatch, + params=params, + dropout_rng=dropout_rng, + freeze_feature_encoder=model_args.freeze_feature_encoder, + train=True, + )[0] + logits_mask = state.get_attention_mask_fn(logits.shape[1], batch["attention_mask"]) + loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean") + + return loss + + grad_fn = jax.value_and_grad(compute_loss) + + if gradient_accumulation_steps == 1 or training_args.multisteps: + loss, grad = grad_fn(to_dtype(state.params), batch) + + # Custom gradient accumulation + else: + # add a first dimension over gradient_accumulation_steps for minibatch slices + batch = jax.tree_util.tree_map( + lambda x: x.reshape( + gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::] + ), + batch, + ) + + def accum_minibatch_step(accum_grad, minibatch): + # compute loss, num labels and grad over minibatch and accumulate + loss, grad = grad_fn(to_dtype(state.params), minibatch) + return jax.tree_util.tree_map(jnp.add, accum_grad, grad), loss + + # create an initial state for accumulating losses, num labels and gradients + init_grad = jax.tree_util.tree_map(jnp.zeros_like, to_dtype(state.params)) + # loop accum minibatch step over the number of gradient accumulation steps + grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch) + + # update state + new_state = state.apply_gradients( + grads=grad, + dropout_rng=new_dropout_rng, + to_dtype=to_dtype, + ) + + # compute gradient norms over all layers and globally for detailed monitoring + layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad) + logs = { + "layer_grad_norm": layer_grad_norm, + "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)), + } + + # compute parameter norms over all layers and globally for detailed monitoring + layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params) + logs["layer_param_norm"] = layer_param_norm + logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm)) + + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + metrics.update(logs) + + metrics = jax.lax.pmean(metrics, axis_name="batch") + # metrics = to_fp32(metrics) + + return new_state, metrics + + # Define eval fn + def eval_step(params, batch): + labels = batch.pop("labels") + logits = model(**batch, params=params, train=False)[0] + + logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], 
batch["attention_mask"]) + loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean") + + pred_ids = jnp.argmax(logits, axis=-1) + + # summarize metrics + metrics = {"loss": loss} + metrics = jax.lax.pmean(metrics, axis_name="batch") + # metrics = to_fp32(metrics) + return metrics, pred_ids + + # Create parallel version of the train and eval step + if training_args.do_train: + p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) + + if training_args.do_eval: + p_eval_step = jax.pmap(eval_step, "batch") + + def run_evaluation(step): + if training_args.do_eval: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_preds = [] + eval_labels = [] + + # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length + eval_samples_idx = get_grouped_indices(vectorized_datasets[data_args.eval_split_name], eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) + + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): + samples = [vectorized_datasets[data_args.eval_split_name][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + labels = batch["labels"] + + metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size) + eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))) + eval_metrics.append(metrics) + + eval_labels.extend(labels) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + eval_metrics = to_fp32(eval_metrics) + + # always run compute metrics + error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels) + eval_metrics.update(error_rate_metric) + error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()]) + + # Print metrics and update progress bar + desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})" + epochs.write(desc) + epochs.desc = desc + + # Save metrics + write_wandb_log(eval_metrics, step, prefix="eval") + write_wandb_pred(pred_str, label_str, step) + # if has_tensorboard and jax.process_index() == 0: + # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str) + + def save_checkpoint(step): + # save and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False) + + skip_epochs = data_args.skip_steps // (num_train_samples // batch_size_per_update) + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_train_samples}") + logger.info(f" Num Epochs = {num_epochs}") + logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") + logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}") + logger.info(f" Total train batch size (w. 
parallel & distributed) = {batch_size_per_update}") + logger.info(f" Total optimization steps = {total_train_steps}") + logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}") + logger.info(f" Use scan: {config.use_scan}") + logger.info(f" Fuse matmuls: {config.fuse_matmuls}") + logger.info(f" Skipping: {data_args.skip_steps} steps ({skip_epochs} epochs)") + + train_time = cur_step = 0 + epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) + for epoch in epochs: + if training_args.do_train: + # ======================== Training ================================ + train_start = time.time() + # Create sampling rng + rng, input_rng = jax.random.split(rng) + + if epoch < skip_epochs: + logger.info(f"Skipping epoch {epoch + 1}") + continue + + # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length + train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng) + train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update) + + if data_args.skip_steps > cur_step: + logger.info(f"Skipping {data_args.skip_steps - (epoch * (num_train_samples // batch_size_per_update))} steps...") + # Gather the indices for creating the batch and do a training step + for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1): + cur_step = epoch * (num_train_samples // batch_size_per_update) + step + if cur_step <= data_args.skip_steps: + continue + + samples = [vectorized_datasets[data_args.train_split_name][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + batch = shard(batch.data) + try: + state, train_metric = p_train_step(state, batch) + except TypeError as e: + logger.warning("Encountered following error: \n", e) + + + if cur_step % training_args.logging_steps == 0: + # Save metrics + train_metric = unreplicate(train_metric) + train_time += time.time() - train_start + # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step + write_wandb_log(to_fp32(train_metric), cur_step, prefix=data_args.train_split_name) + # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis) + # if has_tensorboard and jax.process_index() == 0: + # write_train_metric(summary_writer, train_metrics, train_time, cur_step) + + epochs.write( + f"Step... 
({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})" + ) + + if cur_step % total_train_steps == 0: + break + + if training_args.eval_steps and cur_step % training_args.eval_steps == 0: + run_evaluation(cur_step) + + if cur_step % training_args.save_steps == 0: + save_checkpoint(cur_step) + + if training_args.eval_steps == 0 and (epoch + 1) != num_epochs: + # run evaluation at the end of the epoch if eval steps are not specified + run_evaluation(cur_step) + save_checkpoint(cur_step) + + if training_args.do_train: + save_checkpoint(cur_step) + + cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training + + if training_args.do_eval: + run_evaluation(cur_step) + + # TODO: collapse 'do_predict' into the run_evaluation function + if training_args.do_predict: + for split in [data_args.test_split_name]: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_preds = [] + eval_labels = [] + + # Generate eval set by sequentially sampling indices from the test dataset and grouping by length + eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) + + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)): + samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + labels = batch["labels"] + + metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size) + eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))) + eval_metrics.append(metrics) + + eval_labels.extend(labels) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + eval_metrics = to_fp32(eval_metrics) + + # always run compute metrics + error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels) + eval_metrics.update(error_rate_metric) + error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()]) + + # Print metrics and update progress bar + desc = f"Step... 
({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})" + epochs.write(desc) + epochs.desc = desc + + # Save metrics + write_wandb_log(eval_metrics, cur_step, prefix=split) + write_wandb_pred(pred_str, label_str, cur_step, prefix=split) + # if has_tensorboard and jax.process_index() == 0: + # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str) + + +if __name__ == "__main__": + main() diff --git a/wandb/run-20220811_094956-332xvl6v/files/config.yaml b/wandb/run-20220811_094956-332xvl6v/files/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..833380814f49a07c3843551724eeb029a04a58ed --- /dev/null +++ b/wandb/run-20220811_094956-332xvl6v/files/config.yaml @@ -0,0 +1,27 @@ +wandb_version: 1 + +_wandb: + desc: null + value: + cli_version: 0.12.9 + code_path: code/run_flax_speech_recognition_ctc.py + framework: huggingface + huggingface_version: 4.21.0 + is_jupyter_run: false + is_kaggle_kernel: false + python_version: 3.8.10 + start_time: 1660211396 + t: + 1: + - 1 + - 2 + - 3 + - 11 + - 12 + 3: + - 13 + 4: 3.8.10 + 5: 0.12.9 + 6: 4.21.0 + 8: + - 5 diff --git a/wandb/run-20220811_094956-332xvl6v/files/diff.patch b/wandb/run-20220811_094956-332xvl6v/files/diff.patch new file mode 100644 index 0000000000000000000000000000000000000000..4db9ec36f3981a56f0b4a57b059c739143cafaf8 --- /dev/null +++ b/wandb/run-20220811_094956-332xvl6v/files/diff.patch @@ -0,0 +1,234 @@ +diff --git a/config.json b/config.json +index 260219f..246b797 100644 +--- a/config.json ++++ b/config.json +@@ -5,7 +5,7 @@ + "add_adapter": false, + "apply_spec_augment": true, + "architectures": [ +- "Wav2Vec2ForCTC" ++ "Wav2Vec2ForPreTraining" + ], + "attention_dropout": 0.094, + "bos_token_id": 1, +diff --git a/run.recover.sh b/run.recover.sh +index 77ad3fd..632a336 100755 +--- a/run.recover.sh ++++ b/run.recover.sh +@@ -11,9 +11,8 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c + --per_device_train_batch_size="2" \ + --per_device_eval_batch_size="2" \ + --gradient_accumulation_steps="1" \ +- --precision="full_mixed" \ ++ --precision="half_mixed" \ + --matmul_precision="bfloat16" \ +- --multisteps \ + --learning_rate="6.394633237505332e-05" \ + --skip_steps="275000" \ + --warmup_steps="2000" \ +diff --git a/run.sh b/run.sh +index 8758978..6adf9ee 100755 +--- a/run.sh ++++ b/run.sh +@@ -1,3 +1,6 @@ ++# See https://github.com/sanchit-gandhi/seq2seq-speech/issues/23#issuecomment-1122183173: do_lower_case should only be set to True for the tokenizer if the tokenizer has upper case letters in the vocab ++# Let's also not add extra remove_punctuation ++# And limit max duration to 25 seconds + WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_ctc.py \ + --model_name_or_path="facebook/wav2vec2-xls-r-1b" \ + --hub_model_id="NbAiLab/wav2vec2-1b-npsc-nst-tpu" \ +@@ -11,7 +14,7 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c + --precision="full_mixed" \ + --matmul_precision="bfloat16" \ + --multisteps \ +- --learning_rate="1e-4" \ ++ --learning_rate="2e-5" \ + --warmup_steps="2000" \ + --length_column_name="input_length" \ + --evaluation_strategy="steps" \ +@@ -32,7 +35,7 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c + --mask_feature_length="64" \ + --gradient_checkpointing \ + --min_duration_in_seconds="0.5" \ +- --max_duration_in_seconds="30.0" \ ++ --max_duration_in_seconds="25.0" \ + --use_auth_token \ + --seed="42" 
\ + --group_by_length \ +@@ -40,10 +43,5 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c + --push_to_hub \ + --preprocessing_num_workers="32" \ + --ctc_zero_infinity \ +- --do_lower_case \ + --wandb_project="wav2vec2" \ + --wandb_name="wav2vec2-1b-npsc-nst-tpu" \ +- --remove_punctuation +- +- +-# --fp16 +diff --git a/run_flax_speech_recognition_ctc.py b/run_flax_speech_recognition_ctc.py +index a330879..4a5d940 100644 +--- a/run_flax_speech_recognition_ctc.py ++++ b/run_flax_speech_recognition_ctc.py +@@ -415,12 +415,12 @@ class MixedPrecisionTrainState(struct.PyTreeNode): + ) + + @classmethod +- def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs): ++ def create(cls, *, apply_fn, params, tx, to_dtype, step=0, **kwargs): + """Creates a new instance with `step=0` and initialized `opt_state`.""" + # downcast optimizer state to bf16 if mixed-precision training + opt_state = tx.init(to_dtype(params)) if tx is not None else None + return cls( +- step=0, ++ step=step, + apply_fn=apply_fn, + params=params, + tx=tx, +@@ -1339,6 +1339,7 @@ def main(): + + # Setup train state + state = MixedPrecisionTrainState.create( ++ step=data_args.skip_steps, + apply_fn=model.__call__, + get_attention_mask_fn=model._get_feature_vector_attention_mask, + params=model.params, +@@ -1517,14 +1518,13 @@ def main(): + if training_args.do_train: + # ======================== Training ================================ + train_start = time.time() ++ # Create sampling rng ++ rng, input_rng = jax.random.split(rng) + + if epoch < skip_epochs: + logger.info(f"Skipping epoch {epoch + 1}") + continue + +- # Create sampling rng +- rng, input_rng = jax.random.split(rng) +- + # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length + train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng) + train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update) +diff --git a/special_tokens_map.json b/special_tokens_map.json +index 218961f..0d13bc3 100644 +--- a/special_tokens_map.json ++++ b/special_tokens_map.json +@@ -399,6 +399,90 @@ + "rstrip": false, + "single_word": false + }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ 
"lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, + { + "content": "", + "lstrip": false, +diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log +index 23926ef..a94dcab 120000 +--- a/wandb/debug-internal.log ++++ b/wandb/debug-internal.log +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4/logs/debug-internal.log +\ No newline at end of file ++run-20220811_094956-332xvl6v/logs/debug-internal.log +\ No newline at end of file +diff --git a/wandb/debug.log b/wandb/debug.log +index 279853d..dd4b72e 120000 +--- a/wandb/debug.log ++++ b/wandb/debug.log +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4/logs/debug.log +\ No newline at end of file ++run-20220811_094956-332xvl6v/logs/debug.log +\ No newline at end of file +diff --git a/wandb/latest-run b/wandb/latest-run +index f069a7a..e0f7642 120000 +--- a/wandb/latest-run ++++ b/wandb/latest-run +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4 +\ No newline at end of file ++run-20220811_094956-332xvl6v +\ No newline at end of file diff --git a/wandb/run-20220811_094956-332xvl6v/files/output.log b/wandb/run-20220811_094956-332xvl6v/files/output.log new file mode 100644 index 0000000000000000000000000000000000000000..a28de30aa7774f0c7ade5098d1ea4be5ac7c3d7a --- /dev/null +++ b/wandb/run-20220811_094956-332xvl6v/files/output.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c17805a65ad9a53223277a717b9895858af11b01f40ac61f45c9f058c9c362f +size 106496 diff --git a/wandb/run-20220811_094956-332xvl6v/files/requirements.txt b/wandb/run-20220811_094956-332xvl6v/files/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e0273eb6554b8538eecc3cb9f4a47c988bd3d0dd --- /dev/null +++ b/wandb/run-20220811_094956-332xvl6v/files/requirements.txt @@ -0,0 +1,158 @@ +absl-py==1.0.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +appdirs==1.4.4 +astunparse==1.6.3 +async-timeout==4.0.2 +attrs==21.4.0 +audioread==2.1.9 +backcall==0.2.0 +cachetools==4.2.4 +certifi==2021.10.8 +cffi==1.15.1 +charset-normalizer==2.0.10 +chex==0.1.3 +click==8.0.3 +cloud-tpu-client==0.10 +cloud-tpu-profiler==2.4.0 +clu==0.0.6 +colorama==0.4.5 +commonmark==0.9.1 +configparser==5.2.0 +contextlib2==21.6.0 +cycler==0.11.0 +datasets==2.4.0 +decorator==5.1.0 +dill==0.3.4 +dm-tree==0.1.6 +docker-pycreds==0.4.0 +etils==0.6.0 +exceptiongroup==1.0.0rc8 +filelock==3.4.2 +flatbuffers==2.0 +flax==0.5.3 +fonttools==4.28.5 +frozenlist==1.2.0 +fsspec==2021.11.1 +future==0.18.2 +gast==0.4.0 +gitdb==4.0.9 +gitpython==3.1.26 +google-api-core==1.31.5 +google-api-python-client==1.8.0 +google-auth-httplib2==0.1.0 +google-auth-oauthlib==0.4.6 +google-auth==2.3.3 +google-pasta==0.2.0 +googleapis-common-protos==1.54.0 +grpcio==1.43.0 +h5py==3.6.0 +httplib2==0.20.2 +huggingface-hub==0.2.1 +hypothesis==6.53.0 +idna==3.3 +importlib-metadata==4.10.0 +importlib-resources==5.4.0 +ipython==7.31.0 +jax==0.3.15 +jaxlib==0.3.15 +jedi==0.18.1 +jiwer==2.3.0 +joblib==1.1.0 +keras-preprocessing==1.1.2 +keras==2.7.0 +kiwisolver==1.3.2 +libclang==12.0.0 +librosa==0.9.2 +libtpu-nightly==0.1.dev20220722 +llvmlite==0.39.0 +markdown==3.3.6 +matplotlib-inline==0.1.3 +matplotlib==3.5.1 +ml-collections==0.1.0 +msgpack==1.0.3 +multidict==5.2.0 +multiprocess==0.70.12.2 +numba==0.56.0 +numpy==1.22.0 +oauth2client==4.1.3 +oauthlib==3.1.1 +opt-einsum==3.3.0 +optax==0.1.3 +packaging==21.3 +pandas==1.3.5 +parso==0.8.3 +pathtools==0.1.2 +pexpect==4.8.0 +pickleshare==0.7.5 +pillow==9.0.0 +pip==22.2.2 +pkg-resources==0.0.0 +pooch==1.6.0 +promise==2.3 
+prompt-toolkit==3.0.24 +protobuf==3.19.1 +psutil==5.9.0 +ptyprocess==0.7.0 +pyarrow==6.0.1 +pyasn1-modules==0.2.8 +pyasn1==0.4.8 +pycparser==2.21 +pyctcdecode==0.4.0 +pygments==2.11.1 +pygtrie==2.5.0 +pyparsing==3.0.6 +python-dateutil==2.8.2 +python-levenshtein==0.12.2 +pytz==2021.3 +pyyaml==6.0 +regex==2021.11.10 +requests-oauthlib==1.3.0 +requests==2.27.0 +resampy==0.3.1 +responses==0.18.0 +rich==11.2.0 +rsa==4.8 +sacremoses==0.0.46 +scikit-learn==1.1.1 +scipy==1.7.3 +sentry-sdk==1.5.2 +setuptools==44.0.0 +shortuuid==1.0.8 +six==1.16.0 +smmap==5.0.0 +sortedcontainers==2.4.0 +soundfile==0.10.3.post1 +sox==1.4.1 +subprocess32==3.5.4 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.0 +tensorboard==2.7.0 +tensorflow-cpu==2.7.0 +tensorflow-datasets==4.4.0 +tensorflow-estimator==2.7.0 +tensorflow-io-gcs-filesystem==0.23.1 +tensorflow-metadata==1.5.0 +tensorflow==2.7.0 +tensorstore==0.1.21 +termcolor==1.1.0 +threadpoolctl==3.1.0 +tokenizers==0.11.2 +toolz==0.11.2 +torch==1.12.0 +torchaudio==0.12.0+cpu +tqdm==4.62.3 +traitlets==5.1.1 +transformers==4.21.0 +typing-extensions==4.3.0 +uritemplate==3.0.1 +urllib3==1.26.7 +wandb==0.12.9 +wcwidth==0.2.5 +werkzeug==2.0.2 +wheel==0.37.1 +wrapt==1.13.3 +xxhash==2.0.2 +yarl==1.7.2 +yaspin==2.1.0 +zipp==3.7.0 \ No newline at end of file diff --git a/wandb/run-20220811_094956-332xvl6v/files/wandb-metadata.json b/wandb/run-20220811_094956-332xvl6v/files/wandb-metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..ecf3908c7ac251615784a4864f36c12652a5faf7 --- /dev/null +++ b/wandb/run-20220811_094956-332xvl6v/files/wandb-metadata.json @@ -0,0 +1,67 @@ +{ + "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29", + "python": "3.8.10", + "heartbeatAt": "2022-08-11T09:50:00.091126", + "startedAt": "2022-08-11T09:49:56.562575", + "docker": null, + "cpu_count": 96, + "cuda": null, + "args": [ + "--model_name_or_path=facebook/wav2vec2-xls-r-1b", + "--hub_model_id=NbAiLab/wav2vec2-1b-npsc-nst-tpu", + "--tokenizer_name=./", + "--output_dir=./", + "--overwrite_output_dir", + "--num_train_epochs=40", + "--per_device_train_batch_size=2", + "--per_device_eval_batch_size=2", + "--gradient_accumulation_steps=1", + "--precision=full_mixed", + "--matmul_precision=bfloat16", + "--multisteps", + "--learning_rate=2e-5", + "--warmup_steps=2000", + "--length_column_name=input_length", + "--evaluation_strategy=steps", + "--text_column_name=text", + "--save_steps=5000", + "--eval_steps=5000", + "--logging_steps=100", + "--layerdrop=0.041", + "--attention_dropout=0.094", + "--activation_dropout=0.055", + "--hidden_dropout=0.047", + "--save_total_limit=5", + "--freeze_feature_encoder", + "--feat_proj_dropout=0.04", + "--mask_time_prob=0.082", + "--mask_time_length=10", + "--mask_feature_prob=0.25", + "--mask_feature_length=64", + "--gradient_checkpointing", + "--min_duration_in_seconds=0.5", + "--max_duration_in_seconds=25.0", + "--use_auth_token", + "--seed=42", + "--group_by_length", + "--do_train", + "--do_eval", + "--push_to_hub", + "--preprocessing_num_workers=32", + "--ctc_zero_infinity", + "--wandb_project=wav2vec2", + "--wandb_name=wav2vec2-1b-npsc-nst-tpu" + ], + "state": "running", + "program": "run_flax_speech_recognition_ctc.py", + "codePath": "run_flax_speech_recognition_ctc.py", + "git": { + "remote": "https://huggingface.co/NbAiLab/wav2vec2-1b-npsc-nst-tpu", + "commit": "f624ac4bfedfbf56891676d7c5f2e37b4c8e0745" + }, + "email": "versae@gmail.com", + "root": "/data/wav2vec2-1b-npsc-nst-tpu", + "host": "t1v-n-eedfb410-w-0", + 
"username": "javierr", + "executable": "/data/flax/bin/python" +} diff --git a/wandb/run-20220811_094956-332xvl6v/files/wandb-summary.json b/wandb/run-20220811_094956-332xvl6v/files/wandb-summary.json new file mode 100644 index 0000000000000000000000000000000000000000..9e26dfeeb6e641a33dae4961196235bdb965b21b --- /dev/null +++ b/wandb/run-20220811_094956-332xvl6v/files/wandb-summary.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/wandb/run-20220811_094956-332xvl6v/logs/debug-internal.log b/wandb/run-20220811_094956-332xvl6v/logs/debug-internal.log new file mode 100644 index 0000000000000000000000000000000000000000..3c185c189b5e11bfb52b1bf3258ebfedb1337999 --- /dev/null +++ b/wandb/run-20220811_094956-332xvl6v/logs/debug-internal.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7fc922c34a1d3d878f8a8bc86e3cd037eb1665ca2d693a2fb38212edd5d8ccf6 +size 20480 diff --git a/wandb/run-20220811_094956-332xvl6v/logs/debug.log b/wandb/run-20220811_094956-332xvl6v/logs/debug.log new file mode 100644 index 0000000000000000000000000000000000000000..ed051f2469df924da8842aef16e2ca0d84436b49 --- /dev/null +++ b/wandb/run-20220811_094956-332xvl6v/logs/debug.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:68c2234eda5ccf1e451e2ba85a5b277d4e40f76b86ca6f99b79e397f6d6f2418 +size 2952 diff --git a/wandb/run-20220811_094956-332xvl6v/run-332xvl6v.wandb b/wandb/run-20220811_094956-332xvl6v/run-332xvl6v.wandb new file mode 100644 index 0000000000000000000000000000000000000000..fc82639e9c67a5136196de416f14f1dbfea66336 --- /dev/null +++ b/wandb/run-20220811_094956-332xvl6v/run-332xvl6v.wandb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:95a80db2c22951539a6b6768193829dfe94b038f2067858337e735b711603b19 +size 110592 diff --git a/wandb/run-20220811_101752-mzjvp6ho/files/code/run_flax_speech_recognition_ctc.py b/wandb/run-20220811_101752-mzjvp6ho/files/code/run_flax_speech_recognition_ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..4a5d9404604defc9d7cdab04832844ec56ce7978 --- /dev/null +++ b/wandb/run-20220811_101752-mzjvp6ho/files/code/run_flax_speech_recognition_ctc.py @@ -0,0 +1,1631 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2022 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition. +""" +# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 
+ +import logging +import math +import os +import re +import sys +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + +import datasets +import numpy as np +from datasets import DatasetDict, load_dataset, load_metric +from tqdm import tqdm + +import flax +import jax +import jax.numpy as jnp +import optax +import transformers +import wandb as wandb +from flax import core, jax_utils, struct, traverse_util +from flax.jax_utils import unreplicate, pad_shard_unpad +from flax.training.common_utils import get_metrics, shard, shard_prng_key +from huggingface_hub import Repository +from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC +from optax._src import linear_algebra +from transformers import ( + AutoFeatureExtractor, + AutoProcessor, + AutoTokenizer, + HfArgumentParser, + TrainingArguments, + is_tensorboard_available, + set_seed, +) +from transformers.file_utils import get_full_repo_name +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.17.0.dev0") + +require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") + +logger = logging.getLogger(__name__) + + +@flax.struct.dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + feature_extractor_name: Optional[str] = field( + default=None, metadata={"help": "feature extractor name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + freeze_feature_encoder: bool = field( + default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."} + ) + attention_dropout: float = field( + default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."} + ) + activation_dropout: float = field( + default=0.1, + metadata={ + "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler." + }, + ) + hidden_dropout: float = field( + default=0.1, + metadata={ + "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler." + }, + ) + feat_proj_dropout: float = field( + default=0.0, + metadata={ + "help": "The feat proj dropout probability for feature encoder representations." 
+ }, + ) + final_dropout: float = field( + default=0.0, + metadata={"help": "The dropout probability for the final projection layer."}, + ) + mask_time_prob: float = field( + default=0.1, + metadata={ + "help": "The spec aug dropout probability for feature encoder representations." + }, + ) + mask_time_length: int = field( + default=10, + metadata={"help": "Length of vector span to mask along the time axis."}, + ) + mask_feature_prob: float = field( + default=0.0, + metadata={ + "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector" + "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis." + }, + ) + mask_feature_length: int = field( + default=10, + metadata={"help": "Length of vector span to mask along the feature axis."}, + ) + layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."}) + ctc_loss_reduction: Optional[str] = field( + default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."} + ) + ctc_zero_infinity: Optional[bool] = field( + default=False, metadata={"help": "If True, will try yo aboud the CTC loss goinf to infinity."} + ) + + +@flax.struct.dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: str = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + text_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, + ) + dataset_cache_dir: Optional[str] = field( + default=None, metadata={"help": "Path to cache directory for saving and loading datasets"} + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_test_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of test examples to this " + "value if set." + }, + ) + audio_column_name: str = field( + default="audio", + metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"}, + ) + text_column_name: str = field( + default="text", + metadata={"help": "The name of the dataset column containing the text data. 
Defaults to 'text'"}, + ) + max_duration_in_seconds: float = field( + default=20.0, + metadata={ + "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`" + }, + ) + min_duration_in_seconds: float = field( + default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"} + ) + max_label_length: Optional[int] = field( + default=512, + metadata={ + "help": "The minimum total sequence length for target text after tokenization. Sequences shorter " + "than this will be filtered." + }, + ) + min_label_length: Optional[int] = field( + default=2, + metadata={ + "help": "The minimum total sequence length for target text after tokenization. Sequences shorter " + "than this will be filtered." + }, + ) + pad_input_to_multiple_of: Optional[int] = field( + default=32000, + metadata={ + "help": "If set will pad the input sequence to a multiple of the provided value. " + "This is important to avoid triggering recompilations on TPU." + }, + ) + pad_target_to_multiple_of: Optional[int] = field( + default=None, + metadata={ + "help": "If set will pad the target sequence to a multiple of the provided value. " + "This is important to avoid triggering recompilations on TPU." + }, + ) + preprocessing_only: bool = field( + default=False, + metadata={ + "help": "Whether to only do data preprocessing and skip training. " + "This is especially useful when data preprocessing errors out in distributed training due to timeout. " + "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` " + "so that the cached datasets can consequently be loaded in distributed training" + }, + ) + train_split_name: str = field( + default="train", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + eval_split_name: str = field( + default="validation", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + do_lower_case: bool = field( + default=True, + metadata={"help": "Whether the target text should be lower cased."}, + ) + wandb_project: str = field( + default="flax-speech-recognition-ctc", + metadata={"help": "The name of the wandb project."}, + ) + wandb_name: str = field( + default=None, + metadata={"help": "The name of the wandb run."}, + ) + wandb_job_type: str = field( + default="CTC", + metadata={"help": "The name of the wandb job type."}, + ) + test_split_name: str = field( + default="test", + metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"}, + ) + remove_punctuation: bool = field( + default=False, metadata={"help": "Whether or not to remove punctuation during training."} + ) + skip_steps: Optional[int] = field( + default=0, + metadata={ + "help": "Skip this number of steps. Useful to continue training" + }, + ) + + +# @flax.struct.dataclass +@dataclass +class FlaxTrainingArguments(TrainingArguments): + precision: str = field( + default="full", + metadata={ + "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision" + "**Note that this only specifies the dtype of the computation and optimizer state. 
It does not influence the dtype of model parameters.**" + }, + ) + matmul_precision: str = field( + default="default", + metadata={ + "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. " + "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). " + "This configuration option does not change the behaviours of such calls with explicit precision arguments; " + "it only changes the behaviors of calls with no such argument provided. " + "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`." + }, + ) + multisteps: bool = field( + default=False, + metadata={ + "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, " + "a custom gradient accumulation implementation will be employed." + }, + ) + + +def to_fp32(t): + return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t) + + +def to_bf16(t): + return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t) + + +class MixedPrecisionTrainState(struct.PyTreeNode): + """Train state for use with a single Optax optimizer. + Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py + + Synopsis:: + + state = TrainState.create( + apply_fn=model.apply, + params=variables['params'], + tx=tx) + grad_fn = jax.grad(make_loss_fn(state.apply_fn)) + for batch in data: + grads = grad_fn(state.params, batch) + state = state.apply_gradients(grads=grads) + + Args: + step: Counter starts at 0 and is incremented by every call to + `.apply_gradients()`. + apply_fn: Usually set to `model.apply()`. Kept in this dataclass for + convenience to have a shorter params list for the `train_step()` function + in your training loop. + params: The parameters to be updated by `tx` and used by `apply_fn`. + tx: An Optax gradient transformation. + opt_state: The state for `tx`. + dropout_rng: PRNG key for stochastic operations. + bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. + """ + + step: int + apply_fn: Callable = struct.field(pytree_node=False) + get_attention_mask_fn: Callable = struct.field(pytree_node=False) + params: core.FrozenDict[str, Any] + tx: optax.GradientTransformation = struct.field(pytree_node=False) + opt_state: optax.OptState + dropout_rng: jnp.ndarray + max_grad_norm: Optional[float] = 1.0 + + def apply_gradients(self, *, grads, to_dtype, **kwargs): + """Updates `step`, `params`, `opt_state` and `**kwargs` in return value. + + Note that internally this function calls `.tx.update()` followed by a call + to `optax.apply_updates()` to update `params` and `opt_state`. + + Args: + grads: Gradients that have the same pytree structure as `.params`. + **kwargs: Additional dataclass attributes that should be `.replace()`-ed. + + Returns: + An updated instance of `self` with `step` incremented by one, `params` + and `opt_state` updated by applying `grads`, and additional attributes + replaced as specified by `kwargs`. 
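+
+        Example (an illustrative sketch only; `loss_fn` and `batch` are placeholder
+        names, and the use of `to_bf16` assumes mixed-precision training)::
+
+            grads = jax.grad(loss_fn)(state.params, batch)
+            state = state.apply_gradients(grads=grads, to_dtype=to_bf16)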
+ """ + + # clip gradients by global l2 norm + casted_max_grad_norm = to_dtype(self.max_grad_norm) + g_norm = linear_algebra.global_norm(grads) + g_norm = jnp.maximum(casted_max_grad_norm, g_norm) + grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads) + + # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training + # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is) + updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params) + + new_params = optax.apply_updates(self.params, updates) + return self.replace( + step=self.step + 1, + params=new_params, + opt_state=to_dtype(new_opt_state), + **kwargs, + ) + + @classmethod + def create(cls, *, apply_fn, params, tx, to_dtype, step=0, **kwargs): + """Creates a new instance with `step=0` and initialized `opt_state`.""" + # downcast optimizer state to bf16 if mixed-precision training + opt_state = tx.init(to_dtype(params)) if tx is not None else None + return cls( + step=step, + apply_fn=apply_fn, + params=params, + tx=tx, + opt_state=opt_state, + **kwargs, + ) + + def replicate(self): + return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) + + +@flax.struct.dataclass +class FlaxDataCollatorSpeechSeq2SeqWithPadding: + """ + Data collator that will dynamically pad the inputs received. + Args: + processor ([`Wav2Vec2Processor`]) + The processor used for proccessing the data. + decoder_start_token_id (:obj: `int`) + The begin-of-sentence of the decoder. + input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned input sequences (according to the model's padding side and padding index) + among: + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned target sequences (according to the model's padding side and padding index). + See above for details. + max_input_length (:obj:`float`, `optional`): + Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). + pad_input_to_multiple_of (:obj:`int`, `optional`): + If set will pad the input sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + pad_target_to_multiple_of (:obj:`int`, `optional`): + If set will pad the target sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). 
+ """ + + processor: Any + input_padding: Union[bool, str] = "longest" + label_padding: Union[bool, str] = "max_length" + pad_input_to_multiple_of: Optional[int] = None + pad_to_multiple_of_label: Optional[int] = None + max_input_length: Optional[float] = None + max_label_length: Optional[float] = None + + def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]: + # split inputs and labels since they have to be of different lengths and need + # different padding methods + input_features = [{"input_values": feature["input_values"]} for feature in features] + label_features = [{"input_ids": feature["labels"]} for feature in features] + + # reformat list to dict and set to pytorch format + batch = self.processor.feature_extractor.pad( + input_features, + max_length=self.max_input_length, + padding=self.input_padding, + pad_to_multiple_of=self.pad_input_to_multiple_of, + return_tensors="np", + ) + + labels_batch = self.processor.tokenizer.pad( + label_features, + max_length=self.max_label_length, + padding=self.label_padding, + pad_to_multiple_of=self.pad_to_multiple_of_label, + return_tensors="np", + ) + + labels = labels_batch["input_ids"] + labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1)) + labels = labels.filled(fill_value=-100) + + batch["labels"] = labels + + return batch + + +def get_grouped_indices( + dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None +) -> np.array: + """ + Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486) + Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar + lengths. To do this, the indices are: + + - randomly permuted (if a JAX rng is specified) + - grouped in mega-batches of size `mega_batch_mult * batch_size` + - sorted by length in each mega-batch + + The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of + maximum length placed first, so that an OOM happens sooner rather than later. + """ + lengths = dataset["input_length"] + + # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller. + if mega_batch_mult is None: + mega_batch_mult = min(len(lengths) // (batch_size * 4), 50) + # Just in case, for tiny datasets + if mega_batch_mult == 0: + mega_batch_mult = 1 + + # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler. + num_samples = len(lengths) + indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples) + + megabatch_size = mega_batch_mult * batch_size + megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] + megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches] + + # The rest is to get the biggest batch first. 
+ # Since each megabatch is sorted by descending length, the longest element is the first + megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches] + max_idx = np.argmax(megabatch_maximums).item() + # Switch to put the longest batch in first position + # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch) + megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0] + + megabatches = np.array([i for megabatch in megabatches for i in megabatch]) + + return megabatches + + +def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray: + """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by + the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned.""" + num_samples = len(samples_idx) + if drop_last: + samples_to_remove = num_samples % batch_size + if samples_to_remove != 0: + samples_idx = samples_idx[:-samples_to_remove] + sections_split = num_samples // batch_size + samples_idx = samples_idx.reshape((sections_split, batch_size)) + else: + sections_split = math.ceil(num_samples / batch_size) + samples_idx = np.array_split(samples_idx, sections_split) + return samples_idx + + +def write_train_metric(summary_writer, train_metrics, train_time, step): + summary_writer.scalar("train_time", train_time, step) + + train_metrics = get_metrics(train_metrics) + for key, vals in train_metrics.items(): + tag = f"train_{key}" + for i, val in enumerate(vals): + summary_writer.scalar(tag, val, step - len(vals) + i + 1) + + +def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None): + for metric_name, value in eval_metrics.items(): + summary_writer.scalar(f"eval_{metric_name}", value, step) + + if pred_str is not None: + # write output actual predictions for debugging + summary_writer.text("eval_predictions", "\n".join(pred_str), step) + + +def write_wandb_log(metrics, step, prefix=None): + if jax.process_index() == 0: + log_metrics = {} + for k, v in metrics.items(): + if "layer" in k: + log_metrics[f"{k}/"] = v + elif prefix is not None: + log_metrics[f"{prefix}/{k}"] = v + else: + log_metrics[k] = v + wandb.log(log_metrics, step) + + +def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"): + if jax.process_index() == 0: + # convert str data to a wandb compatible format + str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))] + # we'll log the first 50 predictions for each epoch + wandb.log( + { + f"{prefix}/step_{int(step / 1000)}k": wandb.Table( + columns=["label_str", "pred_str"], data=str_data[:num_log] + ) + }, + step, + ) + + +def create_learning_rate_fn( + num_train_steps: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.array]: + """Returns a linear warmup, linear_decay learning rate function.""" + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +def ctc_loss( + logits, + logits_attention_mask, + labels, + blank_id, + loss_reduction="mean", + output_emission_dict=False, + log_epsilon=-100000.0, +): + """Computes CTC loss. 
+    This function performs forward computation over an FSA with `N * 2` states
+    where `N` is the max number of labels. The states are split into two groups:
+    Phi states and emission states. A phi-state accepts repetition of
+    phi (blank)-symbols and transits to the emission state when the correct label is
+    observed. An emission state accepts repetition of the label and transits to
+    the next phi state at any time (the so-called epsilon-transition).
+    Below, `B` denotes the batch size, `T` denotes the number of time steps in `logits`,
+    and `N` denotes the maximum number of labels.
+    Args:
+      logits: (B, T, K)-array containing log-probabilities of each class.
+      logits_attention_mask: (B, T)-array. Attention mask for `logits`; frame
+        paddings are derived as its logical inverse.
+      labels: (B, N)-array containing reference integer labels. Label paddings
+        are indicated by -100. Currently, `labels` must be right-padded, i.e.
+        each row must consist of labels followed only by padding values.
+      blank_id: Id for the blank token.
+      loss_reduction: one of "none", "mean", "sum"
+        - "none": no reduction is applied.
+        - "mean": output loss will be divided by target lengths and then the
+          mean over the batch is taken.
+        - "sum": output losses are summed over the batch.
+      output_emission_dict: whether to output additional information about the emission probs
+    Returns:
+      A pair `(loss, aux)` when `output_emission_dict` is True, otherwise just `loss`.
+      loss:
+        The per-sequence loss values reduced according to `loss_reduction`
+        (a (B,)-array when no reduction is applied).
+      aux: Dictionary containing interim variables used for computing losses.
+        aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each
+          phi-state corresponding to the n-th label.
+        aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each
+          emission-state corresponding to the n-th label.
+        aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol
+          corresponding to each time frame.
+        aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label
+          corresponding to each time frame.
+    """
+    # label paddings are indicated by -100
+    labelpaddings = labels < 0
+    # logit paddings are the inverse of attention_mask
+    logitpaddings = ~logits_attention_mask
+
+    # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py
+    batchsize, unused_maxinputlen, num_classes = logits.shape
+    batchsize_, maxlabellen = labels.shape
+
+    logprobs = jax.nn.log_softmax(logits)
+    labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32)
+
+    # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
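+    # For example, for a label row [5, 5, 7] the comparison of neighbouring positions
+    # gives [1., 0.] and the trailing zero-pad below yields repeat = [1., 0., 0.],
+    # i.e. only position 0 is followed by a repetition of the same label.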
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32) + repeat = jnp.pad(repeat, ((0, 0), (0, 1))) + + logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1] + logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] + + one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K] + logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot) + logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] + + logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N] + logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0) + logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N] + + def loop_body(prev, x): + prev_phi, prev_emit = prev + # emit-to-phi epsilon transition, except if the next label is repetition + prev_phi_orig = prev_phi + prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat)) + + logprob_emit, logprob_phi, pad = x + + # phi-to-emit transition + next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit) + # self-loop transition + next_phi = prev_phi + logprob_phi + # emit-to-phi blank transition only when the next label is repetition + next_phi = next_phi.at[:, 1:].set( + jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)) + ) + + pad = pad.reshape((batchsize, 1)) + next_emit = pad * prev_emit + (1.0 - pad) * next_emit + next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi + + return (next_phi, next_emit), (next_phi, next_emit) + + xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0))) + _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs) + + # last row needs to be updated with the last epsilon transition + logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1])) + logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last) + + # extract per_seq_loss + one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1] + per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot) + + if loss_reduction == "mean": + target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1) + loss = (per_seq_loss / target_lengths).mean() + elif loss_reduction == "sum": + loss = per_seq_loss.sum() + else: + loss = per_seq_loss + + if not output_emission_dict: + return loss + + return loss, { + "logalpha_phi": logalpha_phi, + "logalpha_emit": logalpha_emit, + "logprobs_phi": logprobs_phi, + "logprobs_emit": logprobs_emit, + } + + +def make_dataset(data_args, seed=42): + # Pre-processing dataset + import re + + def map_nst(entry): + text = entry["text"].lower() + text = text.replace("(...vær stille under dette opptaket...)", "") + text = re.sub('[áàâ]', 'a', text) + text = re.sub('[ä]', 'æ', text) + text = re.sub('[éèëê]', 'e', text) + text = re.sub('[íìïî]', 'i', text) + text = re.sub('[óòöô]', 'o', text) + text = re.sub('[ö]', 'ø', text) + text = re.sub('[ç]', 'c', text) + text = re.sub('[úùüû]', 'u', text) + # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text) + text = re.sub('\s+', ' ', text) + return {"text": text} + + def filter_nst(entry): + if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)): + return False # Too short + if re.match(entry["type"], "pIW|CA"): + return False # Spelling out words + return True + + def filter_npsc(entry): + # False if there are digits in the 
text + if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)): + return False # Too short + if re.search("\d", entry["text"]): + return False + return True + + def map_npsc(entry): + batch = {"text": entry["text"].lower()} + batch["text"] = re.sub('[áàâ]', 'a', batch["text"]) + batch["text"] = re.sub('[ä]', 'æ', batch["text"]) + batch["text"] = re.sub('[éèëê]', 'e', batch["text"]) + batch["text"] = re.sub('[íìïî]', 'i', batch["text"]) + batch["text"] = re.sub('[óòöô]', 'o', batch["text"]) + batch["text"] = re.sub('[ö]', 'ø', batch["text"]) + batch["text"] = re.sub('[ç]', 'c', batch["text"]) + batch["text"] = re.sub('[úùüû]', 'u', batch["text"]) + batch["text"] = re.sub('\s', ' ', batch["text"]) + batch["text"] = re.sub('', 'eee', batch["text"]) + batch["text"] = re.sub('', 'qqq', batch["text"]) + batch["text"] = re.sub('', 'mmm', batch["text"]) + batch["text"] = re.sub('', 'xxx', batch["text"]) + # batch["text"] = re.sub('', '?', batch["text"]) + if "<" in batch["text"]: + raise ValueError(batch["text"]) + return batch + + nst = datasets.load_dataset("NbAiLab/NST", "no-close") + npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3") + # TODO NST_hesitate + + split = len(npsc[data_args.train_split_name]) / (len(npsc[data_args.train_split_name]) + len(npsc[data_args.eval_split_name])) # Use same train/val ratio as NPSC + nst_train = nst[data_args.train_split_name].train_test_split(train_size=split, seed=seed) + nst[data_args.train_split_name] = nst_train["train"] + nst[data_args.eval_split_name] = nst_train["test"] + + nst = nst.filter(filter_nst).map( + map_nst, + num_proc=data_args.preprocessing_num_workers, + desc="filtering NST", + ).shuffle(seed=seed) + npsc = npsc.filter(filter_npsc).map( + map_npsc, + num_proc=data_args.preprocessing_num_workers, + desc="filtering NPSC", + ).shuffle(seed=seed) + + npsc_base = npsc.remove_columns([col for col in npsc[data_args.train_split_name].column_names if col not in ["text", "audio"]]) + nst_base = nst.remove_columns([col for col in nst[data_args.train_split_name].column_names if col not in ["text", "audio"]]) + + combined = {} + for split in data_args.train_split_name, data_args.eval_split_name, data_args.test_split_name: + probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples + probs = (probs / probs.sum()).tolist() + comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed) + combined[split] = comb + + return datasets.DatasetDict(**combined) + +def main(): + # 1. Parse input arguments + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # 2. Setup logging + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + # Set the verbosity to info of the Transformers logger. + # We only want one process per machine to log things on the screen. + logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) + if jax.process_index() == 0: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # Set up wandb run + if jax.process_index() == 0: + wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type) + + logger.info("Training/evaluation parameters %s", training_args) + + # Set the default TPU matmul precision and display the number of devices + jax.config.update("jax_default_matmul_precision", training_args.matmul_precision) + logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}") + + # 4. Load dataset + + set_seed(training_args.seed) + raw_datasets = make_dataset(data_args, seed=training_args.seed) + + # raw_datasets = DatasetDict() + + # if training_args.do_train: + # raw_datasets[data_args.train_split_name] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=data_args.train_split_name, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + # if training_args.do_eval: + # raw_datasets[data_args.eval_split_name] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=data_args.eval_split_name, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + # if training_args.do_predict: + # test_split = data_args.test_split_name.split("+") + # for split in test_split: + # raw_datasets[split] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=split, + # cache_dir=data_args.dataset_cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + + if not training_args.do_train and not training_args.do_eval and not training_args.do_predict: + raise ValueError( + "Cannot not train, not do evaluation and not do prediction. At least one of " + "training, evaluation or prediction has to be done." + ) + + # if not training, there is no need to run multiple epochs + if not training_args.do_train: + training_args.num_train_epochs = 1 + + if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--audio_column_name` to the correct audio column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--text_column_name` to the correct text column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + # 5. 
Load pretrained model, tokenizer, and feature extractor + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + config = Wav2Vec2Config.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + # update config according to training args, model args, and tokenizer attributes + config.update( + { + "feat_proj_dropout": model_args.feat_proj_dropout, + "attention_dropout": model_args.attention_dropout, + "hidden_dropout": model_args.hidden_dropout, + "final_dropout": model_args.final_dropout, + "mask_time_prob": model_args.mask_time_prob, + "mask_time_length": model_args.mask_time_length, + "mask_feature_prob": model_args.mask_feature_prob, + "mask_feature_length": model_args.mask_feature_length, + "gradient_checkpointing": training_args.gradient_checkpointing, + "layerdrop": model_args.layerdrop, + "ctc_loss_reduction": model_args.ctc_loss_reduction, + "ctc_zero_infinity": model_args.ctc_zero_infinity, + "pad_token_id": tokenizer.pad_token_id, + "vocab_size": tokenizer.vocab_size, # len(tokenizer), + "activation_dropout": model_args.activation_dropout, + } + ) + + if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr": + raise ValueError( + "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to " + "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus," + "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is mostly likely " + "built on an lowercased corpus. In this case, set `tokenizer.do_lower_case` to ``False`." + ) + + if training_args.precision == "full_mixed": + dtype = jnp.bfloat16 + training_args.mixed_precision = True + elif training_args.precision == "half_mixed": + dtype = jnp.bfloat16 + training_args.mixed_precision = False + else: + dtype = jnp.float32 + training_args.mixed_precision = False + + try: + model = FlaxWav2Vec2ForCTC.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + except: + model = FlaxWav2Vec2ForCTC.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + from_pt=True, + ) + + # 6. Resample speech dataset ALWAYS + raw_datasets = raw_datasets.cast_column( + data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) + ) + + # 7. Preprocessing the datasets. 
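+    # (Illustrative arithmetic, assuming the usual 16 kHz wav2vec 2.0 feature extractor:
+    # max_duration_in_seconds=20.0 corresponds to max_input_length = 20.0 * 16000 = 320000
+    # input samples, and min_duration_in_seconds=0.0 to min_input_length = 0.)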
+    # We need to read the audio files as arrays and tokenize the targets.
+    max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
+    min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
+    max_target_length = data_args.max_label_length
+    min_target_length = data_args.min_label_length
+    pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
+    audio_column_name = data_args.audio_column_name
+    num_workers = data_args.preprocessing_num_workers
+    text_column_name = data_args.text_column_name
+    model_input_name = feature_extractor.model_input_names[0]
+    do_lower_case = data_args.do_lower_case
+    dataset_name = data_args.dataset_name
+    chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ")
+    chars_to_ignore_regex = f'[{"".join(chars_to_ignore)}]'
+    # gigaspeech_punctuation = {" ": ",", " ": ".", " ": "?", " ": "!"}
+    # gigaspeech_disfluencies = ["", ""]
+    # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "", "", "", "[laughter-",
+    #                     "[vocalized-noise]", "_1"]
+    # swb_punctuations = ["{", "}", "[", "]-", "]"]
+    # earnings_disfluencies = ["", "", "", "inaudible", "", ""]
+    ignore_segments = ["ignore_time_segment_in_scoring", "", "", "[noise]", "[laughter]", "[silence]",
+                       "[vocalized-noise]", "", "", "", "", "", "", ""]
+
+    if training_args.do_train and data_args.max_train_samples is not None:
+        raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].select(range(data_args.max_train_samples))
+
+    if training_args.do_eval and data_args.max_eval_samples is not None:
+        raw_datasets[data_args.eval_split_name] = raw_datasets[data_args.eval_split_name].select(range(data_args.max_eval_samples))
+
+    if training_args.do_predict and data_args.max_test_samples is not None:
+        raw_datasets[data_args.test_split_name] = raw_datasets[data_args.test_split_name].select(range(data_args.max_test_samples))
+
+    if training_args.do_train and data_args.remove_punctuation:
+
+        def remove_punctuation(batch):
+            batch[text_column_name] = (
+                re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "")
+            )
+            # return the updated example so that `datasets.Dataset.map` persists the change
+            return batch
+
+        raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].map(
+            remove_punctuation,
+            num_proc=data_args.preprocessing_num_workers,
+            desc="removing punctuation from train split",
+        )
+
+    # filter data where the targets are ignored in scoring
+    def is_target_labels(input_str):
+        return input_str.lower() not in ignore_segments
+
+    raw_datasets = raw_datasets.filter(
+        is_target_labels,
+        num_proc=num_workers,
+        input_columns=[text_column_name],
+        desc="filtering data where the targets are ignored in scoring",
+    )
+
+    def prepare_dataset(batch):
+        # process audio
+        try:
+            sample = batch[audio_column_name]
+        except ValueError:
+            sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate}
+        inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
+        # process audio length
+        batch[model_input_name] = inputs.input_values[0]
+        batch["input_length"] = len(batch["input_values"])
+
+        # process targets
+        input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
+
+        # if dataset_name == "google/xtreme_s":
+        #     # Finally, we tokenize the processed text
+        #     batch["labels"] = tokenizer(input_str).input_ids
+        #     batch["labels_length"] = len(batch["labels"])
+        #     return batch
+
+        # # Common Voice 9
+        # if input_str.startswith('"') and input_str.endswith('"'):
+        #     # we can remove 
trailing quotation marks as they do not affect the transcription + # input_str = input_str[1:-1] + # # normalize quotation marks + # input_str = re.sub(r'["“”]', '"', input_str) + # # normalize apostrophes + # input_str = re.sub(r"[’']", "'", input_str) + # # normalize hyphens + # input_str = re.sub(r"[—–]", "-", input_str) + # # replace double quotation marks with single + # input_str = input_str.replace('""', '"') + # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str): + # # for CV9, we'll normalize the text to always finish with punctuation + # if input_str[-1] not in [".", "?", "!"]: + # input_str = input_str + "." + + # # TEDLIUM-3 + # # delete the token from the text and replace spaced apostrophes with un-spaced + # input_str = input_str.replace("", "").replace(" '", "'") + + # # GigaSpeech + # for disfluency in gigaspeech_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # convert spelled out punctuation to symbolic form + # for punctuation, replacement in gigaspeech_punctuation.items(): + # input_str = input_str.replace(punctuation, replacement) + # if dataset_name == "speechcolab/gigaspeech" and len(input_str): + # # for GS, we'll normalize the text to always finish with punctuation + # if input_str[-1] not in [".", "?", "!"]: + # input_str = input_str + "." + + # # SWB + # for disfluency in swb_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # remove parenthesised text (test data only) + # input_str = re.sub("[\(].*?[\)]", "", input_str) + # for punctuation in swb_punctuations: + # input_str = input_str.replace(punctuation, "") + # # replace anomalous words with their correct transcriptions + # split_str = input_str.split("/") + # if len(split_str) > 1: + # input_str = " ".join( + # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]]) + + # # Earnings 22 + # for disfluency in earnings_disfluencies: + # input_str = input_str.replace(disfluency, "") + # # replace mal-formatted ellipsis + # input_str = input_str.replace("…", ".") + + # JIWER compliance + # remove multiple spaces + input_str = re.sub(r"\s\s+", " ", input_str) + # strip trailing spaces + input_str = input_str.strip() + + # Finally, we tokenize the processed text + batch["labels"] = tokenizer(input_str).input_ids + batch["labels_length"] = len(batch["labels"]) + return batch + + vectorized_datasets = raw_datasets.map( + prepare_dataset, + remove_columns=next(iter(raw_datasets.values())).column_names, + num_proc=num_workers, + desc="preprocess dataset", + ) + + # filter data with inputs shorter than min_input_length or longer than max_input_length + def is_audio_in_length_range(length): + return length > min_input_length and length < max_input_length + + vectorized_datasets = vectorized_datasets.filter( + is_audio_in_length_range, + num_proc=num_workers, + input_columns=["input_length"], + ) + + # filter data with targets shorter than min_target_length or longer than max_target_length + def is_labels_in_length_range(length): + return length > min_target_length # and length < max_target_length + + vectorized_datasets = vectorized_datasets.filter( + is_labels_in_length_range, + num_proc=num_workers, + input_columns=["labels_length"], + ) + + # for large datasets it is advised to run the preprocessing on a + # single machine first with `args.preprocessing_only` since there will mostly likely + # be a timeout when running the script in distributed mode. 
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the + # cached dataset + if data_args.preprocessing_only: + cache = {k: v.cache_files for k, v in vectorized_datasets.items()} + logger.info(f"Data preprocessing finished. Files cached at {cache}.") + return + + # 8. Load Metrics + wer_metric = load_metric("wer") + cer_metric = load_metric("cer") + + def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]): + padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids)) + + pred_str = tokenizer.batch_decode(pred_ids) + # we do not want to group tokens when computing the metrics + label_str = tokenizer.batch_decode(padded_ids, group_tokens=False) + + wer = wer_metric.compute(predictions=pred_str, references=label_str) + cer = cer_metric.compute(predictions=pred_str, references=label_str) + + return {"wer": wer, "cer": cer}, pred_str, label_str + + # 9. save feature extractor, tokenizer and config + feature_extractor.save_pretrained(training_args.output_dir) + tokenizer.save_pretrained(training_args.output_dir) + config.save_pretrained(training_args.output_dir) + + processor = AutoProcessor.from_pretrained(training_args.output_dir) + + data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding( + processor=processor, + input_padding="longest", + pad_input_to_multiple_of=pad_input_to_multiple_of, + max_label_length=data_args.max_label_length, + ) + + # Enable tensorboard only on the master node + has_tensorboard = is_tensorboard_available() + if has_tensorboard and jax.process_index() == 0: + try: + from flax.metrics.tensorboard import SummaryWriter + + summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) + except ImportError as ie: + has_tensorboard = False + logger.warning( + f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" + ) + else: + logger.warning( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run `pip install tensorboard` to enable." + ) + + # 10. Handle the repository creation + if training_args.push_to_hub: + with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f: + git_lfs_extensions = f.read() + if "*.wandb" not in git_lfs_extensions: + f.write("*.wandb filter=lfs diff=lfs merge=lfs -text") + if training_args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) + else: + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) + + # 11. 
Initialize our training + rng = jax.random.PRNGKey(training_args.seed) + rng, dropout_rng = jax.random.split(rng) + + # Store some constants + max_steps = int(training_args.max_steps) + gradient_accumulation_steps = int(training_args.gradient_accumulation_steps) + train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() + batch_size_per_update = train_batch_size * gradient_accumulation_steps + per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) + eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() + to_dtype = to_bf16 if training_args.mixed_precision else to_fp32 + + if training_args.do_train: + num_train_samples = len(vectorized_datasets[data_args.train_split_name]) + steps_per_epoch = num_train_samples // batch_size_per_update + if max_steps > 0: + num_epochs = -(training_args.max_steps // -steps_per_epoch) + total_train_steps = max_steps + else: + num_epochs = int(training_args.num_train_epochs) + total_train_steps = steps_per_epoch * num_epochs + + # Create learning rate schedule + # Create learning rate schedule + linear_decay_lr_schedule_fn = create_learning_rate_fn( + total_train_steps, + training_args.warmup_steps, + training_args.learning_rate, + ) + + # We use Optax's "masking" functionality to not apply weight decay + # to bias and LayerNorm scale parameters. decay_mask_fn returns a + # mask boolean with the same structure as the parameters. + # The mask is True for parameters that should be decayed. + # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart. + # For FlaxT5, one should correct the layer norm parameter naming + # accordingly - see `run_t5_mlm_flax.py` e.g. + def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + layer_norm_params = [ + (name, "scale") + for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"] + ] + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + + if training_args.adafactor: + # Create Adafactor optimizer + optim = optax.adafactor( + learning_rate=linear_decay_lr_schedule_fn, + dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32, + weight_decay_rate=training_args.weight_decay, + weight_decay_mask=decay_mask_fn, + ) + else: + # Create AdamW optimizer + optim = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn, + ) + + # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. 
gradient accumulation steps > 1) + if training_args.multisteps and gradient_accumulation_steps > 1: + optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False) + else: + num_epochs = 0 + total_train_steps = 0 + num_train_samples = 0 + optim = None + + # Setup train state + state = MixedPrecisionTrainState.create( + step=data_args.skip_steps, + apply_fn=model.__call__, + get_attention_mask_fn=model._get_feature_vector_attention_mask, + params=model.params, + tx=optim, + to_dtype=to_dtype, + dropout_rng=dropout_rng, + max_grad_norm=training_args.max_grad_norm, + ) + + # Replicate the train state on each device + state = state.replicate() + blank_id = model.config.pad_token_id + + # Define gradient update step fn + def train_step(state, batch): + # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch + dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + + def compute_loss(params, minibatch): + labels = minibatch.pop("labels") + logits = state.apply_fn( + **minibatch, + params=params, + dropout_rng=dropout_rng, + freeze_feature_encoder=model_args.freeze_feature_encoder, + train=True, + )[0] + logits_mask = state.get_attention_mask_fn(logits.shape[1], batch["attention_mask"]) + loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean") + + return loss + + grad_fn = jax.value_and_grad(compute_loss) + + if gradient_accumulation_steps == 1 or training_args.multisteps: + loss, grad = grad_fn(to_dtype(state.params), batch) + + # Custom gradient accumulation + else: + # add a first dimension over gradient_accumulation_steps for minibatch slices + batch = jax.tree_util.tree_map( + lambda x: x.reshape( + gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::] + ), + batch, + ) + + def accum_minibatch_step(accum_grad, minibatch): + # compute loss, num labels and grad over minibatch and accumulate + loss, grad = grad_fn(to_dtype(state.params), minibatch) + return jax.tree_util.tree_map(jnp.add, accum_grad, grad), loss + + # create an initial state for accumulating losses, num labels and gradients + init_grad = jax.tree_util.tree_map(jnp.zeros_like, to_dtype(state.params)) + # loop accum minibatch step over the number of gradient accumulation steps + grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch) + + # update state + new_state = state.apply_gradients( + grads=grad, + dropout_rng=new_dropout_rng, + to_dtype=to_dtype, + ) + + # compute gradient norms over all layers and globally for detailed monitoring + layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad) + logs = { + "layer_grad_norm": layer_grad_norm, + "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)), + } + + # compute parameter norms over all layers and globally for detailed monitoring + layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params) + logs["layer_param_norm"] = layer_param_norm + logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm)) + + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + metrics.update(logs) + + metrics = jax.lax.pmean(metrics, axis_name="batch") + # metrics = to_fp32(metrics) + + return new_state, metrics + + # Define eval fn + def eval_step(params, batch): + labels = batch.pop("labels") + logits = model(**batch, params=params, train=False)[0] + + logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], 
batch["attention_mask"]) + loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean") + + pred_ids = jnp.argmax(logits, axis=-1) + + # summarize metrics + metrics = {"loss": loss} + metrics = jax.lax.pmean(metrics, axis_name="batch") + # metrics = to_fp32(metrics) + return metrics, pred_ids + + # Create parallel version of the train and eval step + if training_args.do_train: + p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) + + if training_args.do_eval: + p_eval_step = jax.pmap(eval_step, "batch") + + def run_evaluation(step): + if training_args.do_eval: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_preds = [] + eval_labels = [] + + # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length + eval_samples_idx = get_grouped_indices(vectorized_datasets[data_args.eval_split_name], eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) + + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): + samples = [vectorized_datasets[data_args.eval_split_name][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + labels = batch["labels"] + + metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size) + eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))) + eval_metrics.append(metrics) + + eval_labels.extend(labels) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + eval_metrics = to_fp32(eval_metrics) + + # always run compute metrics + error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels) + eval_metrics.update(error_rate_metric) + error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()]) + + # Print metrics and update progress bar + desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})" + epochs.write(desc) + epochs.desc = desc + + # Save metrics + write_wandb_log(eval_metrics, step, prefix="eval") + write_wandb_pred(pred_str, label_str, step) + # if has_tensorboard and jax.process_index() == 0: + # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str) + + def save_checkpoint(step): + # save and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False) + + skip_epochs = data_args.skip_steps // (num_train_samples // batch_size_per_update) + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_train_samples}") + logger.info(f" Num Epochs = {num_epochs}") + logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") + logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}") + logger.info(f" Total train batch size (w. 
parallel & distributed) = {batch_size_per_update}") + logger.info(f" Total optimization steps = {total_train_steps}") + logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}") + logger.info(f" Use scan: {config.use_scan}") + logger.info(f" Fuse matmuls: {config.fuse_matmuls}") + logger.info(f" Skipping: {data_args.skip_steps} steps ({skip_epochs} epochs)") + + train_time = cur_step = 0 + epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) + for epoch in epochs: + if training_args.do_train: + # ======================== Training ================================ + train_start = time.time() + # Create sampling rng + rng, input_rng = jax.random.split(rng) + + if epoch < skip_epochs: + logger.info(f"Skipping epoch {epoch + 1}") + continue + + # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length + train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng) + train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update) + + if data_args.skip_steps > cur_step: + logger.info(f"Skipping {data_args.skip_steps - (epoch * (num_train_samples // batch_size_per_update))} steps...") + # Gather the indices for creating the batch and do a training step + for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1): + cur_step = epoch * (num_train_samples // batch_size_per_update) + step + if cur_step <= data_args.skip_steps: + continue + + samples = [vectorized_datasets[data_args.train_split_name][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + batch = shard(batch.data) + try: + state, train_metric = p_train_step(state, batch) + except TypeError as e: + logger.warning("Encountered following error: \n", e) + + + if cur_step % training_args.logging_steps == 0: + # Save metrics + train_metric = unreplicate(train_metric) + train_time += time.time() - train_start + # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step + write_wandb_log(to_fp32(train_metric), cur_step, prefix=data_args.train_split_name) + # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis) + # if has_tensorboard and jax.process_index() == 0: + # write_train_metric(summary_writer, train_metrics, train_time, cur_step) + + epochs.write( + f"Step... 
({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})" + ) + + if cur_step % total_train_steps == 0: + break + + if training_args.eval_steps and cur_step % training_args.eval_steps == 0: + run_evaluation(cur_step) + + if cur_step % training_args.save_steps == 0: + save_checkpoint(cur_step) + + if training_args.eval_steps == 0 and (epoch + 1) != num_epochs: + # run evaluation at the end of the epoch if eval steps are not specified + run_evaluation(cur_step) + save_checkpoint(cur_step) + + if training_args.do_train: + save_checkpoint(cur_step) + + cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training + + if training_args.do_eval: + run_evaluation(cur_step) + + # TODO: collapse 'do_predict' into the run_evaluation function + if training_args.do_predict: + for split in [data_args.test_split_name]: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_preds = [] + eval_labels = [] + + # Generate eval set by sequentially sampling indices from the test dataset and grouping by length + eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) + + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)): + samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx] + batch = data_collator(samples) + labels = batch["labels"] + + metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size) + eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))) + eval_metrics.append(metrics) + + eval_labels.extend(labels) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + eval_metrics = to_fp32(eval_metrics) + + # always run compute metrics + error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels) + eval_metrics.update(error_rate_metric) + error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()]) + + # Print metrics and update progress bar + desc = f"Step... 
({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})" + epochs.write(desc) + epochs.desc = desc + + # Save metrics + write_wandb_log(eval_metrics, cur_step, prefix=split) + write_wandb_pred(pred_str, label_str, cur_step, prefix=split) + # if has_tensorboard and jax.process_index() == 0: + # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str) + + +if __name__ == "__main__": + main() diff --git a/wandb/run-20220811_101752-mzjvp6ho/files/config.yaml b/wandb/run-20220811_101752-mzjvp6ho/files/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5a2a2a61411f202b7b4cfdd4ce61deb01b3b5aa3 --- /dev/null +++ b/wandb/run-20220811_101752-mzjvp6ho/files/config.yaml @@ -0,0 +1,27 @@ +wandb_version: 1 + +_wandb: + desc: null + value: + cli_version: 0.12.9 + code_path: code/run_flax_speech_recognition_ctc.py + framework: huggingface + huggingface_version: 4.21.0 + is_jupyter_run: false + is_kaggle_kernel: false + python_version: 3.8.10 + start_time: 1660213072 + t: + 1: + - 1 + - 2 + - 3 + - 11 + - 12 + 3: + - 13 + 4: 3.8.10 + 5: 0.12.9 + 6: 4.21.0 + 8: + - 5 diff --git a/wandb/run-20220811_101752-mzjvp6ho/files/diff.patch b/wandb/run-20220811_101752-mzjvp6ho/files/diff.patch new file mode 100644 index 0000000000000000000000000000000000000000..cd564161dff871e1a6dfd979afdaab6db8ac522a --- /dev/null +++ b/wandb/run-20220811_101752-mzjvp6ho/files/diff.patch @@ -0,0 +1,234 @@ +diff --git a/config.json b/config.json +index 260219f..246b797 100644 +--- a/config.json ++++ b/config.json +@@ -5,7 +5,7 @@ + "add_adapter": false, + "apply_spec_augment": true, + "architectures": [ +- "Wav2Vec2ForCTC" ++ "Wav2Vec2ForPreTraining" + ], + "attention_dropout": 0.094, + "bos_token_id": 1, +diff --git a/run.recover.sh b/run.recover.sh +index 77ad3fd..632a336 100755 +--- a/run.recover.sh ++++ b/run.recover.sh +@@ -11,9 +11,8 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c + --per_device_train_batch_size="2" \ + --per_device_eval_batch_size="2" \ + --gradient_accumulation_steps="1" \ +- --precision="full_mixed" \ ++ --precision="half_mixed" \ + --matmul_precision="bfloat16" \ +- --multisteps \ + --learning_rate="6.394633237505332e-05" \ + --skip_steps="275000" \ + --warmup_steps="2000" \ +diff --git a/run.sh b/run.sh +index 8758978..6adf9ee 100755 +--- a/run.sh ++++ b/run.sh +@@ -1,3 +1,6 @@ ++# See https://github.com/sanchit-gandhi/seq2seq-speech/issues/23#issuecomment-1122183173: do_lower_case should only be set to True for the tokenizer if the tokenizer has upper case letters in the vocab ++# Let's also not add extra remove_punctuation ++# And limit max duration to 25 seconds + WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_ctc.py \ + --model_name_or_path="facebook/wav2vec2-xls-r-1b" \ + --hub_model_id="NbAiLab/wav2vec2-1b-npsc-nst-tpu" \ +@@ -11,7 +14,7 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c + --precision="full_mixed" \ + --matmul_precision="bfloat16" \ + --multisteps \ +- --learning_rate="1e-4" \ ++ --learning_rate="2e-5" \ + --warmup_steps="2000" \ + --length_column_name="input_length" \ + --evaluation_strategy="steps" \ +@@ -32,7 +35,7 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c + --mask_feature_length="64" \ + --gradient_checkpointing \ + --min_duration_in_seconds="0.5" \ +- --max_duration_in_seconds="30.0" \ ++ --max_duration_in_seconds="25.0" \ + --use_auth_token \ + --seed="42" 
\ + --group_by_length \ +@@ -40,10 +43,5 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c + --push_to_hub \ + --preprocessing_num_workers="32" \ + --ctc_zero_infinity \ +- --do_lower_case \ + --wandb_project="wav2vec2" \ + --wandb_name="wav2vec2-1b-npsc-nst-tpu" \ +- --remove_punctuation +- +- +-# --fp16 +diff --git a/run_flax_speech_recognition_ctc.py b/run_flax_speech_recognition_ctc.py +index a330879..4a5d940 100644 +--- a/run_flax_speech_recognition_ctc.py ++++ b/run_flax_speech_recognition_ctc.py +@@ -415,12 +415,12 @@ class MixedPrecisionTrainState(struct.PyTreeNode): + ) + + @classmethod +- def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs): ++ def create(cls, *, apply_fn, params, tx, to_dtype, step=0, **kwargs): + """Creates a new instance with `step=0` and initialized `opt_state`.""" + # downcast optimizer state to bf16 if mixed-precision training + opt_state = tx.init(to_dtype(params)) if tx is not None else None + return cls( +- step=0, ++ step=step, + apply_fn=apply_fn, + params=params, + tx=tx, +@@ -1339,6 +1339,7 @@ def main(): + + # Setup train state + state = MixedPrecisionTrainState.create( ++ step=data_args.skip_steps, + apply_fn=model.__call__, + get_attention_mask_fn=model._get_feature_vector_attention_mask, + params=model.params, +@@ -1517,14 +1518,13 @@ def main(): + if training_args.do_train: + # ======================== Training ================================ + train_start = time.time() ++ # Create sampling rng ++ rng, input_rng = jax.random.split(rng) + + if epoch < skip_epochs: + logger.info(f"Skipping epoch {epoch + 1}") + continue + +- # Create sampling rng +- rng, input_rng = jax.random.split(rng) +- + # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length + train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng) + train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update) +diff --git a/special_tokens_map.json b/special_tokens_map.json +index 218961f..0d13bc3 100644 +--- a/special_tokens_map.json ++++ b/special_tokens_map.json +@@ -399,6 +399,90 @@ + "rstrip": false, + "single_word": false + }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ "lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, ++ { ++ "content": "", ++ 
"lstrip": false, ++ "normalized": true, ++ "rstrip": false, ++ "single_word": false ++ }, + { + "content": "", + "lstrip": false, +diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log +index 23926ef..737c280 120000 +--- a/wandb/debug-internal.log ++++ b/wandb/debug-internal.log +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4/logs/debug-internal.log +\ No newline at end of file ++run-20220811_101752-mzjvp6ho/logs/debug-internal.log +\ No newline at end of file +diff --git a/wandb/debug.log b/wandb/debug.log +index 279853d..fcc4539 120000 +--- a/wandb/debug.log ++++ b/wandb/debug.log +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4/logs/debug.log +\ No newline at end of file ++run-20220811_101752-mzjvp6ho/logs/debug.log +\ No newline at end of file +diff --git a/wandb/latest-run b/wandb/latest-run +index f069a7a..e1726b4 120000 +--- a/wandb/latest-run ++++ b/wandb/latest-run +@@ -1 +1 @@ +-run-20220805_230151-2y71vcu4 +\ No newline at end of file ++run-20220811_101752-mzjvp6ho +\ No newline at end of file diff --git a/wandb/run-20220811_101752-mzjvp6ho/files/output.log b/wandb/run-20220811_101752-mzjvp6ho/files/output.log new file mode 100644 index 0000000000000000000000000000000000000000..6ae981e61487952ebcbe0d5672a389f201beb87f --- /dev/null +++ b/wandb/run-20220811_101752-mzjvp6ho/files/output.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ff4e95f5d6abb04acd6be9e6e65bfa8ca83a77e617cd7115444eb1c34c3cfca1 +size 200531 diff --git a/wandb/run-20220811_101752-mzjvp6ho/files/requirements.txt b/wandb/run-20220811_101752-mzjvp6ho/files/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e0273eb6554b8538eecc3cb9f4a47c988bd3d0dd --- /dev/null +++ b/wandb/run-20220811_101752-mzjvp6ho/files/requirements.txt @@ -0,0 +1,158 @@ +absl-py==1.0.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +appdirs==1.4.4 +astunparse==1.6.3 +async-timeout==4.0.2 +attrs==21.4.0 +audioread==2.1.9 +backcall==0.2.0 +cachetools==4.2.4 +certifi==2021.10.8 +cffi==1.15.1 +charset-normalizer==2.0.10 +chex==0.1.3 +click==8.0.3 +cloud-tpu-client==0.10 +cloud-tpu-profiler==2.4.0 +clu==0.0.6 +colorama==0.4.5 +commonmark==0.9.1 +configparser==5.2.0 +contextlib2==21.6.0 +cycler==0.11.0 +datasets==2.4.0 +decorator==5.1.0 +dill==0.3.4 +dm-tree==0.1.6 +docker-pycreds==0.4.0 +etils==0.6.0 +exceptiongroup==1.0.0rc8 +filelock==3.4.2 +flatbuffers==2.0 +flax==0.5.3 +fonttools==4.28.5 +frozenlist==1.2.0 +fsspec==2021.11.1 +future==0.18.2 +gast==0.4.0 +gitdb==4.0.9 +gitpython==3.1.26 +google-api-core==1.31.5 +google-api-python-client==1.8.0 +google-auth-httplib2==0.1.0 +google-auth-oauthlib==0.4.6 +google-auth==2.3.3 +google-pasta==0.2.0 +googleapis-common-protos==1.54.0 +grpcio==1.43.0 +h5py==3.6.0 +httplib2==0.20.2 +huggingface-hub==0.2.1 +hypothesis==6.53.0 +idna==3.3 +importlib-metadata==4.10.0 +importlib-resources==5.4.0 +ipython==7.31.0 +jax==0.3.15 +jaxlib==0.3.15 +jedi==0.18.1 +jiwer==2.3.0 +joblib==1.1.0 +keras-preprocessing==1.1.2 +keras==2.7.0 +kiwisolver==1.3.2 +libclang==12.0.0 +librosa==0.9.2 +libtpu-nightly==0.1.dev20220722 +llvmlite==0.39.0 +markdown==3.3.6 +matplotlib-inline==0.1.3 +matplotlib==3.5.1 +ml-collections==0.1.0 +msgpack==1.0.3 +multidict==5.2.0 +multiprocess==0.70.12.2 +numba==0.56.0 +numpy==1.22.0 +oauth2client==4.1.3 +oauthlib==3.1.1 +opt-einsum==3.3.0 +optax==0.1.3 +packaging==21.3 +pandas==1.3.5 +parso==0.8.3 +pathtools==0.1.2 +pexpect==4.8.0 +pickleshare==0.7.5 +pillow==9.0.0 +pip==22.2.2 +pkg-resources==0.0.0 +pooch==1.6.0 +promise==2.3 
+prompt-toolkit==3.0.24 +protobuf==3.19.1 +psutil==5.9.0 +ptyprocess==0.7.0 +pyarrow==6.0.1 +pyasn1-modules==0.2.8 +pyasn1==0.4.8 +pycparser==2.21 +pyctcdecode==0.4.0 +pygments==2.11.1 +pygtrie==2.5.0 +pyparsing==3.0.6 +python-dateutil==2.8.2 +python-levenshtein==0.12.2 +pytz==2021.3 +pyyaml==6.0 +regex==2021.11.10 +requests-oauthlib==1.3.0 +requests==2.27.0 +resampy==0.3.1 +responses==0.18.0 +rich==11.2.0 +rsa==4.8 +sacremoses==0.0.46 +scikit-learn==1.1.1 +scipy==1.7.3 +sentry-sdk==1.5.2 +setuptools==44.0.0 +shortuuid==1.0.8 +six==1.16.0 +smmap==5.0.0 +sortedcontainers==2.4.0 +soundfile==0.10.3.post1 +sox==1.4.1 +subprocess32==3.5.4 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.0 +tensorboard==2.7.0 +tensorflow-cpu==2.7.0 +tensorflow-datasets==4.4.0 +tensorflow-estimator==2.7.0 +tensorflow-io-gcs-filesystem==0.23.1 +tensorflow-metadata==1.5.0 +tensorflow==2.7.0 +tensorstore==0.1.21 +termcolor==1.1.0 +threadpoolctl==3.1.0 +tokenizers==0.11.2 +toolz==0.11.2 +torch==1.12.0 +torchaudio==0.12.0+cpu +tqdm==4.62.3 +traitlets==5.1.1 +transformers==4.21.0 +typing-extensions==4.3.0 +uritemplate==3.0.1 +urllib3==1.26.7 +wandb==0.12.9 +wcwidth==0.2.5 +werkzeug==2.0.2 +wheel==0.37.1 +wrapt==1.13.3 +xxhash==2.0.2 +yarl==1.7.2 +yaspin==2.1.0 +zipp==3.7.0 \ No newline at end of file diff --git a/wandb/run-20220811_101752-mzjvp6ho/files/wandb-metadata.json b/wandb/run-20220811_101752-mzjvp6ho/files/wandb-metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..17de88c8c5d4e35bcfebe3f8fb4a820f24f1f6e9 --- /dev/null +++ b/wandb/run-20220811_101752-mzjvp6ho/files/wandb-metadata.json @@ -0,0 +1,67 @@ +{ + "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29", + "python": "3.8.10", + "heartbeatAt": "2022-08-11T10:17:55.885456", + "startedAt": "2022-08-11T10:17:52.376096", + "docker": null, + "cpu_count": 96, + "cuda": null, + "args": [ + "--model_name_or_path=facebook/wav2vec2-xls-r-1b", + "--hub_model_id=NbAiLab/wav2vec2-1b-npsc-nst-tpu", + "--tokenizer_name=./", + "--output_dir=./", + "--overwrite_output_dir", + "--num_train_epochs=40", + "--per_device_train_batch_size=2", + "--per_device_eval_batch_size=2", + "--gradient_accumulation_steps=1", + "--precision=full_mixed", + "--matmul_precision=bfloat16", + "--multisteps", + "--learning_rate=2e-5", + "--warmup_steps=2000", + "--length_column_name=input_length", + "--evaluation_strategy=steps", + "--text_column_name=text", + "--save_steps=5000", + "--eval_steps=5000", + "--logging_steps=100", + "--layerdrop=0.041", + "--attention_dropout=0.094", + "--activation_dropout=0.055", + "--hidden_dropout=0.047", + "--save_total_limit=5", + "--freeze_feature_encoder", + "--feat_proj_dropout=0.04", + "--mask_time_prob=0.082", + "--mask_time_length=10", + "--mask_feature_prob=0.25", + "--mask_feature_length=64", + "--gradient_checkpointing", + "--min_duration_in_seconds=0.5", + "--max_duration_in_seconds=25.0", + "--use_auth_token", + "--seed=42", + "--group_by_length", + "--do_train", + "--do_eval", + "--push_to_hub", + "--preprocessing_num_workers=32", + "--ctc_zero_infinity", + "--wandb_project=wav2vec2", + "--wandb_name=wav2vec2-1b-npsc-nst-tpu" + ], + "state": "running", + "program": "run_flax_speech_recognition_ctc.py", + "codePath": "run_flax_speech_recognition_ctc.py", + "git": { + "remote": "https://huggingface.co/NbAiLab/wav2vec2-1b-npsc-nst-tpu", + "commit": "f624ac4bfedfbf56891676d7c5f2e37b4c8e0745" + }, + "email": "versae@gmail.com", + "root": "/data/wav2vec2-1b-npsc-nst-tpu", + "host": "t1v-n-eedfb410-w-0", + 
"username": "javierr", + "executable": "/data/flax/bin/python" +} diff --git a/wandb/run-20220811_101752-mzjvp6ho/files/wandb-summary.json b/wandb/run-20220811_101752-mzjvp6ho/files/wandb-summary.json new file mode 100644 index 0000000000000000000000000000000000000000..5ed89343abe5a09c67166be9b552a69b308af3aa --- /dev/null +++ b/wandb/run-20220811_101752-mzjvp6ho/files/wandb-summary.json @@ -0,0 +1 @@ +{"train/grad_norm": 14.625, "layer_grad_norm/": {"lm_head": {"bias": 0.1162109375, "kernel": 2.265625}, "wav2vec2": {"encoder": {"layer_norm": {"bias": 0.08837890625, "scale": 0.0810546875}, "layers": {"0": {"attention": {"k_proj": {"bias": 0.00031280517578125, "kernel": 0.23828125}, "out_proj": {"bias": 0.12890625, "kernel": 1.1953125}, "q_proj": {"bias": 0.0208740234375, "kernel": 0.30859375}, "v_proj": {"bias": 0.09521484375, "kernel": 0.828125}}, "feed_forward": {"intermediate_dense": {"bias": 0.1474609375, "kernel": 2.03125}, "output_dense": {"bias": 0.06640625, "kernel": 1.6875}}, "final_layer_norm": {"bias": 0.345703125, "scale": 0.46484375}, "layer_norm": {"bias": 0.17578125, "scale": 0.345703125}}, "1": {"attention": {"k_proj": {"bias": 0.00018024444580078125, "kernel": 0.1396484375}, "out_proj": {"bias": 0.0791015625, "kernel": 0.85546875}, "q_proj": {"bias": 0.0125732421875, "kernel": 0.142578125}, "v_proj": {"bias": 0.1123046875, "kernel": 0.7109375}}, "feed_forward": {"intermediate_dense": {"bias": 0.09375, "kernel": 1.3515625}, "output_dense": {"bias": 0.07373046875, "kernel": 1.1875}}, "final_layer_norm": {"bias": 0.16015625, "scale": 0.1328125}, "layer_norm": {"bias": 0.169921875, "scale": 0.169921875}}, "10": {"attention": {"k_proj": {"bias": 0.00010633468627929688, "kernel": 0.3125}, "out_proj": {"bias": 0.056640625, "kernel": 0.68359375}, "q_proj": {"bias": 0.0185546875, "kernel": 0.322265625}, "v_proj": {"bias": 0.08447265625, "kernel": 0.78515625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0810546875, "kernel": 1.25}, "output_dense": {"bias": 0.05517578125, "kernel": 0.9765625}}, "final_layer_norm": {"bias": 0.1318359375, "scale": 0.12353515625}, "layer_norm": {"bias": 0.154296875, "scale": 0.1044921875}}, "11": {"attention": {"k_proj": {"bias": 0.000148773193359375, "kernel": 0.31640625}, "out_proj": {"bias": 0.053466796875, "kernel": 0.8125}, "q_proj": {"bias": 0.0185546875, "kernel": 0.30078125}, "v_proj": {"bias": 0.0869140625, "kernel": 0.9140625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0771484375, "kernel": 1.234375}, "output_dense": {"bias": 0.052001953125, "kernel": 0.88671875}}, "final_layer_norm": {"bias": 0.12255859375, "scale": 0.10205078125}, "layer_norm": {"bias": 0.1484375, "scale": 0.1435546875}}, "12": {"attention": {"k_proj": {"bias": 9.822845458984375e-05, "kernel": 0.26953125}, "out_proj": {"bias": 0.0537109375, "kernel": 0.6875}, "q_proj": {"bias": 0.01470947265625, "kernel": 0.259765625}, "v_proj": {"bias": 0.08154296875, "kernel": 0.7890625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0830078125, "kernel": 1.265625}, "output_dense": {"bias": 0.0498046875, "kernel": 0.875}}, "final_layer_norm": {"bias": 0.1337890625, "scale": 0.130859375}, "layer_norm": {"bias": 0.1259765625, "scale": 0.1044921875}}, "13": {"attention": {"k_proj": {"bias": 0.0001811981201171875, "kernel": 0.396484375}, "out_proj": {"bias": 0.05224609375, "kernel": 0.76953125}, "q_proj": {"bias": 0.023193359375, "kernel": 0.365234375}, "v_proj": {"bias": 0.08984375, "kernel": 0.9609375}}, "feed_forward": {"intermediate_dense": {"bias": 0.0810546875, 
"kernel": 1.1796875}, "output_dense": {"bias": 0.05126953125, "kernel": 0.87109375}}, "final_layer_norm": {"bias": 0.130859375, "scale": 0.109375}, "layer_norm": {"bias": 0.1279296875, "scale": 0.09765625}}, "14": {"attention": {"k_proj": {"bias": 0.00020503997802734375, "kernel": 0.244140625}, "out_proj": {"bias": 0.050048828125, "kernel": 0.74609375}, "q_proj": {"bias": 0.0150146484375, "kernel": 0.2421875}, "v_proj": {"bias": 0.076171875, "kernel": 0.84765625}}, "feed_forward": {"intermediate_dense": {"bias": 0.07861328125, "kernel": 1.1484375}, "output_dense": {"bias": 0.0498046875, "kernel": 0.8984375}}, "final_layer_norm": {"bias": 0.12890625, "scale": 0.0966796875}, "layer_norm": {"bias": 0.10986328125, "scale": 0.123046875}}, "15": {"attention": {"k_proj": {"bias": 0.000194549560546875, "kernel": 0.30078125}, "out_proj": {"bias": 0.05126953125, "kernel": 0.97265625}, "q_proj": {"bias": 0.0179443359375, "kernel": 0.2734375}, "v_proj": {"bias": 0.0810546875, "kernel": 0.9140625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0751953125, "kernel": 1.0703125}, "output_dense": {"bias": 0.0517578125, "kernel": 0.9296875}}, "final_layer_norm": {"bias": 0.12060546875, "scale": 0.099609375}, "layer_norm": {"bias": 0.1123046875, "scale": 0.11376953125}}, "16": {"attention": {"k_proj": {"bias": 0.00013256072998046875, "kernel": 0.265625}, "out_proj": {"bias": 0.053466796875, "kernel": 0.6953125}, "q_proj": {"bias": 0.01611328125, "kernel": 0.25}, "v_proj": {"bias": 0.078125, "kernel": 0.78125}}, "feed_forward": {"intermediate_dense": {"bias": 0.07421875, "kernel": 1.078125}, "output_dense": {"bias": 0.0517578125, "kernel": 0.95703125}}, "final_layer_norm": {"bias": 0.11572265625, "scale": 0.0986328125}, "layer_norm": {"bias": 0.1142578125, "scale": 0.1142578125}}, "17": {"attention": {"k_proj": {"bias": 0.000118255615234375, "kernel": 0.259765625}, "out_proj": {"bias": 0.058349609375, "kernel": 0.6328125}, "q_proj": {"bias": 0.0164794921875, "kernel": 0.26171875}, "v_proj": {"bias": 0.0869140625, "kernel": 0.7578125}}, "feed_forward": {"intermediate_dense": {"bias": 0.080078125, "kernel": 1.140625}, "output_dense": {"bias": 0.056396484375, "kernel": 0.984375}}, "final_layer_norm": {"bias": 0.1279296875, "scale": 0.1171875}, "layer_norm": {"bias": 0.1298828125, "scale": 0.125}}, "18": {"attention": {"k_proj": {"bias": 0.000171661376953125, "kernel": 0.30078125}, "out_proj": {"bias": 0.05712890625, "kernel": 0.8125}, "q_proj": {"bias": 0.0189208984375, "kernel": 0.298828125}, "v_proj": {"bias": 0.0849609375, "kernel": 0.7890625}}, "feed_forward": {"intermediate_dense": {"bias": 0.078125, "kernel": 1.171875}, "output_dense": {"bias": 0.0556640625, "kernel": 1.0546875}}, "final_layer_norm": {"bias": 0.12451171875, "scale": 0.1123046875}, "layer_norm": {"bias": 0.126953125, "scale": 0.1591796875}}, "19": {"attention": {"k_proj": {"bias": 0.00013446807861328125, "kernel": 0.26171875}, "out_proj": {"bias": 0.058349609375, "kernel": 0.625}, "q_proj": {"bias": 0.015380859375, "kernel": 0.27734375}, "v_proj": {"bias": 0.083984375, "kernel": 0.7421875}}, "feed_forward": {"intermediate_dense": {"bias": 0.07763671875, "kernel": 1.21875}, "output_dense": {"bias": 0.05712890625, "kernel": 1.09375}}, "final_layer_norm": {"bias": 0.123046875, "scale": 0.1015625}, "layer_norm": {"bias": 0.1201171875, "scale": 0.12353515625}}, "2": {"attention": {"k_proj": {"bias": 0.0001544952392578125, "kernel": 0.19921875}, "out_proj": {"bias": 0.08837890625, "kernel": 0.90625}, "q_proj": {"bias": 0.0166015625, 
"kernel": 0.2001953125}, "v_proj": {"bias": 0.1416015625, "kernel": 1.0078125}}, "feed_forward": {"intermediate_dense": {"bias": 0.10400390625, "kernel": 1.703125}, "output_dense": {"bias": 0.0771484375, "kernel": 1.296875}}, "final_layer_norm": {"bias": 0.177734375, "scale": 0.1396484375}, "layer_norm": {"bias": 0.193359375, "scale": 0.18359375}}, "20": {"attention": {"k_proj": {"bias": 6.818771362304688e-05, "kernel": 0.2314453125}, "out_proj": {"bias": 0.0625, "kernel": 0.42578125}, "q_proj": {"bias": 0.0142822265625, "kernel": 0.275390625}, "v_proj": {"bias": 0.08642578125, "kernel": 0.55078125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0830078125, "kernel": 1.328125}, "output_dense": {"bias": 0.061279296875, "kernel": 1.109375}}, "final_layer_norm": {"bias": 0.1328125, "scale": 0.11669921875}, "layer_norm": {"bias": 0.123046875, "scale": 0.099609375}}, "21": {"attention": {"k_proj": {"bias": 0.00012493133544921875, "kernel": 0.24609375}, "out_proj": {"bias": 0.0634765625, "kernel": 0.64453125}, "q_proj": {"bias": 0.01495361328125, "kernel": 0.28125}, "v_proj": {"bias": 0.08837890625, "kernel": 0.7265625}}, "feed_forward": {"intermediate_dense": {"bias": 0.08154296875, "kernel": 1.328125}, "output_dense": {"bias": 0.06201171875, "kernel": 1.0859375}}, "final_layer_norm": {"bias": 0.12890625, "scale": 0.134765625}, "layer_norm": {"bias": 0.1181640625, "scale": 0.1123046875}}, "22": {"attention": {"k_proj": {"bias": 8.249282836914062e-05, "kernel": 0.265625}, "out_proj": {"bias": 0.068359375, "kernel": 0.5703125}, "q_proj": {"bias": 0.0174560546875, "kernel": 0.296875}, "v_proj": {"bias": 0.09326171875, "kernel": 0.671875}}, "feed_forward": {"intermediate_dense": {"bias": 0.087890625, "kernel": 1.375}, "output_dense": {"bias": 0.0673828125, "kernel": 1.09375}}, "final_layer_norm": {"bias": 0.142578125, "scale": 0.162109375}, "layer_norm": {"bias": 0.1357421875, "scale": 0.169921875}}, "23": {"attention": {"k_proj": {"bias": 0.000270843505859375, "kernel": 0.34765625}, "out_proj": {"bias": 0.072265625, "kernel": 1.0546875}, "q_proj": {"bias": 0.020263671875, "kernel": 0.3515625}, "v_proj": {"bias": 0.10595703125, "kernel": 1.109375}}, "feed_forward": {"intermediate_dense": {"bias": 0.0859375, "kernel": 1.359375}, "output_dense": {"bias": 0.0703125, "kernel": 1.0703125}}, "final_layer_norm": {"bias": 0.1357421875, "scale": 0.1533203125}, "layer_norm": {"bias": 0.15234375, "scale": 0.1376953125}}, "24": {"attention": {"k_proj": {"bias": 0.000186920166015625, "kernel": 0.333984375}, "out_proj": {"bias": 0.0654296875, "kernel": 0.8046875}, "q_proj": {"bias": 0.020751953125, "kernel": 0.34375}, "v_proj": {"bias": 0.10107421875, "kernel": 0.875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0859375, "kernel": 1.3984375}, "output_dense": {"bias": 0.0634765625, "kernel": 1.046875}}, "final_layer_norm": {"bias": 0.138671875, "scale": 0.1640625}, "layer_norm": {"bias": 0.1591796875, "scale": 0.1494140625}}, "25": {"attention": {"k_proj": {"bias": 0.00018310546875, "kernel": 0.3046875}, "out_proj": {"bias": 0.0673828125, "kernel": 0.90625}, "q_proj": {"bias": 0.02001953125, "kernel": 0.3046875}, "v_proj": {"bias": 0.1025390625, "kernel": 0.96875}}, "feed_forward": {"intermediate_dense": {"bias": 0.087890625, "kernel": 1.4296875}, "output_dense": {"bias": 0.06591796875, "kernel": 1.0625}}, "final_layer_norm": {"bias": 0.146484375, "scale": 0.125}, "layer_norm": {"bias": 0.154296875, "scale": 0.1953125}}, "26": {"attention": {"k_proj": {"bias": 0.00016880035400390625, "kernel": 
0.390625}, "out_proj": {"bias": 0.0654296875, "kernel": 0.87890625}, "q_proj": {"bias": 0.0263671875, "kernel": 0.412109375}, "v_proj": {"bias": 0.1025390625, "kernel": 0.953125}}, "feed_forward": {"intermediate_dense": {"bias": 0.08740234375, "kernel": 1.328125}, "output_dense": {"bias": 0.0673828125, "kernel": 1.078125}}, "final_layer_norm": {"bias": 0.14453125, "scale": 0.1181640625}, "layer_norm": {"bias": 0.146484375, "scale": 0.265625}}, "27": {"attention": {"k_proj": {"bias": 0.000225067138671875, "kernel": 0.4140625}, "out_proj": {"bias": 0.06103515625, "kernel": 1.0078125}, "q_proj": {"bias": 0.0250244140625, "kernel": 0.400390625}, "v_proj": {"bias": 0.0986328125, "kernel": 1.015625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0859375, "kernel": 1.3046875}, "output_dense": {"bias": 0.06201171875, "kernel": 1.046875}}, "final_layer_norm": {"bias": 0.14453125, "scale": 0.16015625}, "layer_norm": {"bias": 0.15625, "scale": 0.1220703125}}, "28": {"attention": {"k_proj": {"bias": 0.000209808349609375, "kernel": 0.3671875}, "out_proj": {"bias": 0.0556640625, "kernel": 0.99609375}, "q_proj": {"bias": 0.023681640625, "kernel": 0.40234375}, "v_proj": {"bias": 0.0908203125, "kernel": 1.015625}}, "feed_forward": {"intermediate_dense": {"bias": 0.078125, "kernel": 1.265625}, "output_dense": {"bias": 0.05810546875, "kernel": 1.03125}}, "final_layer_norm": {"bias": 0.1298828125, "scale": 0.1328125}, "layer_norm": {"bias": 0.1484375, "scale": 0.1767578125}}, "29": {"attention": {"k_proj": {"bias": 0.00017070770263671875, "kernel": 0.365234375}, "out_proj": {"bias": 0.052490234375, "kernel": 0.8046875}, "q_proj": {"bias": 0.0186767578125, "kernel": 0.3984375}, "v_proj": {"bias": 0.08544921875, "kernel": 0.890625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0751953125, "kernel": 1.34375}, "output_dense": {"bias": 0.050048828125, "kernel": 0.9921875}}, "final_layer_norm": {"bias": 0.11572265625, "scale": 0.1142578125}, "layer_norm": {"bias": 0.1474609375, "scale": 0.1005859375}}, "3": {"attention": {"k_proj": {"bias": 0.000278472900390625, "kernel": 0.326171875}, "out_proj": {"bias": 0.0859375, "kernel": 1.2265625}, "q_proj": {"bias": 0.023193359375, "kernel": 0.31640625}, "v_proj": {"bias": 0.1337890625, "kernel": 1.234375}}, "feed_forward": {"intermediate_dense": {"bias": 0.107421875, "kernel": 1.75}, "output_dense": {"bias": 0.0771484375, "kernel": 1.265625}}, "final_layer_norm": {"bias": 0.181640625, "scale": 0.169921875}, "layer_norm": {"bias": 0.189453125, "scale": 0.1748046875}}, "30": {"attention": {"k_proj": {"bias": 0.0001850128173828125, "kernel": 0.396484375}, "out_proj": {"bias": 0.04931640625, "kernel": 0.8203125}, "q_proj": {"bias": 0.02099609375, "kernel": 0.421875}, "v_proj": {"bias": 0.078125, "kernel": 0.921875}}, "feed_forward": {"intermediate_dense": {"bias": 0.06982421875, "kernel": 1.28125}, "output_dense": {"bias": 0.04736328125, "kernel": 0.84765625}}, "final_layer_norm": {"bias": 0.1083984375, "scale": 0.11376953125}, "layer_norm": {"bias": 0.10888671875, "scale": 0.1162109375}}, "31": {"attention": {"k_proj": {"bias": 0.00018310546875, "kernel": 0.40234375}, "out_proj": {"bias": 0.0439453125, "kernel": 0.7265625}, "q_proj": {"bias": 0.02294921875, "kernel": 0.42578125}, "v_proj": {"bias": 0.068359375, "kernel": 0.8203125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0615234375, "kernel": 1.125}, "output_dense": {"bias": 0.04150390625, "kernel": 0.7421875}}, "final_layer_norm": {"bias": 0.0966796875, "scale": 0.1015625}, "layer_norm": {"bias": 
0.1064453125, "scale": 0.1787109375}}, "32": {"attention": {"k_proj": {"bias": 0.00019073486328125, "kernel": 0.33203125}, "out_proj": {"bias": 0.0390625, "kernel": 0.5703125}, "q_proj": {"bias": 0.0185546875, "kernel": 0.361328125}, "v_proj": {"bias": 0.056640625, "kernel": 0.65234375}}, "feed_forward": {"intermediate_dense": {"bias": 0.052978515625, "kernel": 1.0}, "output_dense": {"bias": 0.032958984375, "kernel": 0.625}}, "final_layer_norm": {"bias": 0.0888671875, "scale": 0.08251953125}, "layer_norm": {"bias": 0.08349609375, "scale": 0.09716796875}}, "33": {"attention": {"k_proj": {"bias": 0.0001430511474609375, "kernel": 0.3203125}, "out_proj": {"bias": 0.03076171875, "kernel": 0.55078125}, "q_proj": {"bias": 0.015869140625, "kernel": 0.328125}, "v_proj": {"bias": 0.048583984375, "kernel": 0.60546875}}, "feed_forward": {"intermediate_dense": {"bias": 0.04345703125, "kernel": 0.79296875}, "output_dense": {"bias": 0.02783203125, "kernel": 0.5546875}}, "final_layer_norm": {"bias": 0.0732421875, "scale": 0.0625}, "layer_norm": {"bias": 0.06787109375, "scale": 0.0791015625}}, "34": {"attention": {"k_proj": {"bias": 0.00013637542724609375, "kernel": 0.236328125}, "out_proj": {"bias": 0.024169921875, "kernel": 0.5}, "q_proj": {"bias": 0.01171875, "kernel": 0.23828125}, "v_proj": {"bias": 0.03759765625, "kernel": 0.51953125}}, "feed_forward": {"intermediate_dense": {"bias": 0.03466796875, "kernel": 0.62109375}, "output_dense": {"bias": 0.0228271484375, "kernel": 0.484375}}, "final_layer_norm": {"bias": 0.05615234375, "scale": 0.055908203125}, "layer_norm": {"bias": 0.0546875, "scale": 0.044921875}}, "35": {"attention": {"k_proj": {"bias": 0.0001773834228515625, "kernel": 0.1640625}, "out_proj": {"bias": 0.020751953125, "kernel": 0.48828125}, "q_proj": {"bias": 0.0081787109375, "kernel": 0.166015625}, "v_proj": {"bias": 0.02880859375, "kernel": 0.4453125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0299072265625, "kernel": 0.5234375}, "output_dense": {"bias": 0.0208740234375, "kernel": 0.44140625}}, "final_layer_norm": {"bias": 0.04931640625, "scale": 0.04541015625}, "layer_norm": {"bias": 0.04541015625, "scale": 0.04248046875}}, "36": {"attention": {"k_proj": {"bias": 9.34600830078125e-05, "kernel": 0.1865234375}, "out_proj": {"bias": 0.01953125, "kernel": 0.3984375}, "q_proj": {"bias": 0.0084228515625, "kernel": 0.1767578125}, "v_proj": {"bias": 0.0267333984375, "kernel": 0.357421875}}, "feed_forward": {"intermediate_dense": {"bias": 0.02880859375, "kernel": 0.48046875}, "output_dense": {"bias": 0.0196533203125, "kernel": 0.38671875}}, "final_layer_norm": {"bias": 0.04638671875, "scale": 0.046630859375}, "layer_norm": {"bias": 0.043212890625, "scale": 0.03759765625}}, "37": {"attention": {"k_proj": {"bias": 9.34600830078125e-05, "kernel": 0.171875}, "out_proj": {"bias": 0.0179443359375, "kernel": 0.41015625}, "q_proj": {"bias": 0.0089111328125, "kernel": 0.1689453125}, "v_proj": {"bias": 0.026123046875, "kernel": 0.36328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.025390625, "kernel": 0.4453125}, "output_dense": {"bias": 0.017578125, "kernel": 0.373046875}}, "final_layer_norm": {"bias": 0.041748046875, "scale": 0.03662109375}, "layer_norm": {"bias": 0.0498046875, "scale": 0.03857421875}}, "38": {"attention": {"k_proj": {"bias": 7.677078247070312e-05, "kernel": 0.154296875}, "out_proj": {"bias": 0.0159912109375, "kernel": 0.359375}, "q_proj": {"bias": 0.0069580078125, "kernel": 0.1435546875}, "v_proj": {"bias": 0.02294921875, "kernel": 0.322265625}}, "feed_forward": 
{"intermediate_dense": {"bias": 0.0220947265625, "kernel": 0.41796875}, "output_dense": {"bias": 0.0157470703125, "kernel": 0.375}}, "final_layer_norm": {"bias": 0.035888671875, "scale": 0.037353515625}, "layer_norm": {"bias": 0.04052734375, "scale": 0.0299072265625}}, "39": {"attention": {"k_proj": {"bias": 6.389617919921875e-05, "kernel": 0.1162109375}, "out_proj": {"bias": 0.01385498046875, "kernel": 0.35546875}, "q_proj": {"bias": 0.005706787109375, "kernel": 0.1181640625}, "v_proj": {"bias": 0.0196533203125, "kernel": 0.298828125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0213623046875, "kernel": 0.416015625}, "output_dense": {"bias": 0.013671875, "kernel": 0.41796875}}, "final_layer_norm": {"bias": 0.03466796875, "scale": 0.03173828125}, "layer_norm": {"bias": 0.03271484375, "scale": 0.0341796875}}, "4": {"attention": {"k_proj": {"bias": 0.000324249267578125, "kernel": 0.38671875}, "out_proj": {"bias": 0.0810546875, "kernel": 1.40625}, "q_proj": {"bias": 0.0252685546875, "kernel": 0.38671875}, "v_proj": {"bias": 0.12109375, "kernel": 1.359375}}, "feed_forward": {"intermediate_dense": {"bias": 0.10693359375, "kernel": 1.65625}, "output_dense": {"bias": 0.0751953125, "kernel": 1.265625}}, "final_layer_norm": {"bias": 0.1796875, "scale": 0.19921875}, "layer_norm": {"bias": 0.171875, "scale": 0.21484375}}, "40": {"attention": {"k_proj": {"bias": 3.62396240234375e-05, "kernel": 0.1015625}, "out_proj": {"bias": 0.01324462890625, "kernel": 0.30078125}, "q_proj": {"bias": 0.004486083984375, "kernel": 0.1015625}, "v_proj": {"bias": 0.019287109375, "kernel": 0.291015625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0205078125, "kernel": 0.4140625}, "output_dense": {"bias": 0.0133056640625, "kernel": 0.34375}}, "final_layer_norm": {"bias": 0.036865234375, "scale": 0.04443359375}, "layer_norm": {"bias": 0.0301513671875, "scale": 0.0322265625}}, "41": {"attention": {"k_proj": {"bias": 5.269050598144531e-05, "kernel": 0.1181640625}, "out_proj": {"bias": 0.01177978515625, "kernel": 0.294921875}, "q_proj": {"bias": 0.0052490234375, "kernel": 0.1220703125}, "v_proj": {"bias": 0.0185546875, "kernel": 0.3359375}}, "feed_forward": {"intermediate_dense": {"bias": 0.0181884765625, "kernel": 0.40625}, "output_dense": {"bias": 0.011962890625, "kernel": 0.34765625}}, "final_layer_norm": {"bias": 0.03369140625, "scale": 0.038818359375}, "layer_norm": {"bias": 0.03076171875, "scale": 0.05810546875}}, "42": {"attention": {"k_proj": {"bias": 3.1948089599609375e-05, "kernel": 0.060546875}, "out_proj": {"bias": 0.01171875, "kernel": 0.22265625}, "q_proj": {"bias": 0.00286865234375, "kernel": 0.06640625}, "v_proj": {"bias": 0.01531982421875, "kernel": 0.23828125}}, "feed_forward": {"intermediate_dense": {"bias": 0.016357421875, "kernel": 0.38671875}, "output_dense": {"bias": 0.0123291015625, "kernel": 0.30078125}}, "final_layer_norm": {"bias": 0.0264892578125, "scale": 0.0289306640625}, "layer_norm": {"bias": 0.0242919921875, "scale": 0.037109375}}, "43": {"attention": {"k_proj": {"bias": 1.6927719116210938e-05, "kernel": 0.04736328125}, "out_proj": {"bias": 0.0125732421875, "kernel": 0.189453125}, "q_proj": {"bias": 0.0025787353515625, "kernel": 0.0517578125}, "v_proj": {"bias": 0.01507568359375, "kernel": 0.208984375}}, "feed_forward": {"intermediate_dense": {"bias": 0.017822265625, "kernel": 0.435546875}, "output_dense": {"bias": 0.013427734375, "kernel": 0.298828125}}, "final_layer_norm": {"bias": 0.029541015625, "scale": 0.031982421875}, "layer_norm": {"bias": 0.0257568359375, "scale": 
0.03759765625}}, "44": {"attention": {"k_proj": {"bias": 1.5497207641601562e-05, "kernel": 0.050048828125}, "out_proj": {"bias": 0.01361083984375, "kernel": 0.212890625}, "q_proj": {"bias": 0.0026092529296875, "kernel": 0.05419921875}, "v_proj": {"bias": 0.0162353515625, "kernel": 0.2294921875}}, "feed_forward": {"intermediate_dense": {"bias": 0.016357421875, "kernel": 0.43359375}, "output_dense": {"bias": 0.01422119140625, "kernel": 0.27734375}}, "final_layer_norm": {"bias": 0.026123046875, "scale": 0.02734375}, "layer_norm": {"bias": 0.0303955078125, "scale": 0.028076171875}}, "45": {"attention": {"k_proj": {"bias": 1.7642974853515625e-05, "kernel": 0.0546875}, "out_proj": {"bias": 0.01385498046875, "kernel": 0.22265625}, "q_proj": {"bias": 0.00286865234375, "kernel": 0.05810546875}, "v_proj": {"bias": 0.017333984375, "kernel": 0.236328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.014404296875, "kernel": 0.38671875}, "output_dense": {"bias": 0.0142822265625, "kernel": 0.2734375}}, "final_layer_norm": {"bias": 0.022705078125, "scale": 0.0267333984375}, "layer_norm": {"bias": 0.0390625, "scale": 0.0303955078125}}, "46": {"attention": {"k_proj": {"bias": 2.1696090698242188e-05, "kernel": 0.064453125}, "out_proj": {"bias": 0.0135498046875, "kernel": 0.216796875}, "q_proj": {"bias": 0.003204345703125, "kernel": 0.0615234375}, "v_proj": {"bias": 0.0185546875, "kernel": 0.255859375}}, "feed_forward": {"intermediate_dense": {"bias": 0.0115966796875, "kernel": 0.296875}, "output_dense": {"bias": 0.01300048828125, "kernel": 0.232421875}}, "final_layer_norm": {"bias": 0.0194091796875, "scale": 0.0263671875}, "layer_norm": {"bias": 0.048095703125, "scale": 0.034423828125}}, "47": {"attention": {"k_proj": {"bias": 2.1696090698242188e-05, "kernel": 0.0625}, "out_proj": {"bias": 0.01336669921875, "kernel": 0.177734375}, "q_proj": {"bias": 0.003143310546875, "kernel": 0.052734375}, "v_proj": {"bias": 0.021240234375, "kernel": 0.26171875}}, "feed_forward": {"intermediate_dense": {"bias": 0.010009765625, "kernel": 0.2041015625}, "output_dense": {"bias": 0.01251220703125, "kernel": 0.1904296875}}, "final_layer_norm": {"bias": 0.019287109375, "scale": 0.01806640625}, "layer_norm": {"bias": 0.0576171875, "scale": 0.0390625}}, "5": {"attention": {"k_proj": {"bias": 0.00016117095947265625, "kernel": 0.33984375}, "out_proj": {"bias": 0.0810546875, "kernel": 0.9140625}, "q_proj": {"bias": 0.0216064453125, "kernel": 0.3515625}, "v_proj": {"bias": 0.1201171875, "kernel": 0.9765625}}, "feed_forward": {"intermediate_dense": {"bias": 0.103515625, "kernel": 1.53125}, "output_dense": {"bias": 0.07861328125, "kernel": 1.1796875}}, "final_layer_norm": {"bias": 0.17578125, "scale": 0.1435546875}, "layer_norm": {"bias": 0.17578125, "scale": 0.1220703125}}, "6": {"attention": {"k_proj": {"bias": 0.000225067138671875, "kernel": 0.404296875}, "out_proj": {"bias": 0.07568359375, "kernel": 1.1953125}, "q_proj": {"bias": 0.02392578125, "kernel": 0.39453125}, "v_proj": {"bias": 0.1240234375, "kernel": 1.2578125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0986328125, "kernel": 1.5390625}, "output_dense": {"bias": 0.072265625, "kernel": 1.1171875}}, "final_layer_norm": {"bias": 0.166015625, "scale": 0.146484375}, "layer_norm": {"bias": 0.189453125, "scale": 0.1708984375}}, "7": {"attention": {"k_proj": {"bias": 0.000263214111328125, "kernel": 0.41015625}, "out_proj": {"bias": 0.0703125, "kernel": 1.171875}, "q_proj": {"bias": 0.0263671875, "kernel": 0.3984375}, "v_proj": {"bias": 0.1123046875, "kernel": 
1.2109375}}, "feed_forward": {"intermediate_dense": {"bias": 0.09228515625, "kernel": 1.5}, "output_dense": {"bias": 0.0654296875, "kernel": 1.09375}}, "final_layer_norm": {"bias": 0.15234375, "scale": 0.142578125}, "layer_norm": {"bias": 0.185546875, "scale": 0.1572265625}}, "8": {"attention": {"k_proj": {"bias": 0.0002307891845703125, "kernel": 0.359375}, "out_proj": {"bias": 0.06494140625, "kernel": 1.0}, "q_proj": {"bias": 0.021240234375, "kernel": 0.35546875}, "v_proj": {"bias": 0.10107421875, "kernel": 1.0546875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0888671875, "kernel": 1.4453125}, "output_dense": {"bias": 0.06298828125, "kernel": 1.0625}}, "final_layer_norm": {"bias": 0.1435546875, "scale": 0.1337890625}, "layer_norm": {"bias": 0.15625, "scale": 0.1865234375}}, "9": {"attention": {"k_proj": {"bias": 0.0002803802490234375, "kernel": 0.427734375}, "out_proj": {"bias": 0.057373046875, "kernel": 1.2109375}, "q_proj": {"bias": 0.0235595703125, "kernel": 0.4140625}, "v_proj": {"bias": 0.09228515625, "kernel": 1.28125}}, "feed_forward": {"intermediate_dense": {"bias": 0.080078125, "kernel": 1.3203125}, "output_dense": {"bias": 0.056640625, "kernel": 1.0546875}}, "final_layer_norm": {"bias": 0.12890625, "scale": 0.12890625}, "layer_norm": {"bias": 0.1494140625, "scale": 0.1484375}}}, "pos_conv_embed": {"conv": {"bias": 0.13671875, "weight_g": 0.095703125, "weight_v": 1.0625}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "1": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "2": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "3": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "4": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "5": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "6": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}}}, "feature_projection": {"layer_norm": {"bias": 0.341796875, "scale": 0.5}, "projection": {"bias": 0.2060546875, "kernel": 4.0625}}, "masked_spec_embed": 0.0}}, "layer_param_norm/": {"lm_head": {"bias": 0.022005891427397728, "kernel": 4.73419713973999}, "wav2vec2": {"encoder": {"layer_norm": {"bias": 0.8130701780319214, "scale": 22.378780364990234}, "layers": {"0": {"attention": {"k_proj": {"bias": 0.03131743520498276, "kernel": 25.905803680419922}, "out_proj": {"bias": 1.540170669555664, "kernel": 25.08538818359375}, "q_proj": {"bias": 1.3004887104034424, "kernel": 26.19013023376465}, "v_proj": {"bias": 0.34544575214385986, "kernel": 25.80384635925293}}, "feed_forward": {"intermediate_dense": {"bias": 1.7621228694915771, "kernel": 95.1478042602539}, "output_dense": {"bias": 1.0215306282043457, "kernel": 90.9341049194336}}, "final_layer_norm": {"bias": 1.2818951606750488, "scale": 19.882598876953125}, "layer_norm": {"bias": 3.2894985675811768, "scale": 16.021102905273438}}, "1": {"attention": {"k_proj": {"bias": 0.030251234769821167, "kernel": 40.206363677978516}, "out_proj": {"bias": 1.2913923263549805, "kernel": 41.64579772949219}, "q_proj": {"bias": 2.8594934940338135, "kernel": 40.05158996582031}, "v_proj": {"bias": 0.28184354305267334, "kernel": 40.12499237060547}}, "feed_forward": {"intermediate_dense": {"bias": 1.5795958042144775, "kernel": 93.23052978515625}, "output_dense": {"bias": 0.8024207353591919, "kernel": 84.13617706298828}}, 
"final_layer_norm": {"bias": 1.1360442638397217, "scale": 18.409835815429688}, "layer_norm": {"bias": 1.7416155338287354, "scale": 19.35036849975586}}, "10": {"attention": {"k_proj": {"bias": 0.04441862925887108, "kernel": 47.289459228515625}, "out_proj": {"bias": 1.220796823501587, "kernel": 50.119895935058594}, "q_proj": {"bias": 2.42268705368042, "kernel": 47.24949645996094}, "v_proj": {"bias": 0.31348833441734314, "kernel": 50.3177490234375}}, "feed_forward": {"intermediate_dense": {"bias": 1.6234996318817139, "kernel": 97.58567810058594}, "output_dense": {"bias": 0.563186526298523, "kernel": 91.52272033691406}}, "final_layer_norm": {"bias": 2.201132297515869, "scale": 20.358192443847656}, "layer_norm": {"bias": 1.6924121379852295, "scale": 22.291942596435547}}, "11": {"attention": {"k_proj": {"bias": 0.09936677664518356, "kernel": 47.08089065551758}, "out_proj": {"bias": 1.0724854469299316, "kernel": 49.310142517089844}, "q_proj": {"bias": 2.474160671234131, "kernel": 46.80278778076172}, "v_proj": {"bias": 0.35527026653289795, "kernel": 49.852378845214844}}, "feed_forward": {"intermediate_dense": {"bias": 1.6722913980484009, "kernel": 98.33531188964844}, "output_dense": {"bias": 0.5459500551223755, "kernel": 93.23268127441406}}, "final_layer_norm": {"bias": 2.182143449783325, "scale": 20.366065979003906}, "layer_norm": {"bias": 1.6736292839050293, "scale": 22.601032257080078}}, "12": {"attention": {"k_proj": {"bias": 0.048039671033620834, "kernel": 47.672481536865234}, "out_proj": {"bias": 1.0567779541015625, "kernel": 49.622215270996094}, "q_proj": {"bias": 2.361361265182495, "kernel": 47.42518615722656}, "v_proj": {"bias": 0.342715322971344, "kernel": 50.03790283203125}}, "feed_forward": {"intermediate_dense": {"bias": 1.7160598039627075, "kernel": 99.17654418945312}, "output_dense": {"bias": 0.5353658199310303, "kernel": 94.77962493896484}}, "final_layer_norm": {"bias": 2.139207363128662, "scale": 20.327411651611328}, "layer_norm": {"bias": 1.737513542175293, "scale": 23.147457122802734}}, "13": {"attention": {"k_proj": {"bias": 0.07115713506937027, "kernel": 49.556034088134766}, "out_proj": {"bias": 1.0499722957611084, "kernel": 49.264190673828125}, "q_proj": {"bias": 2.3340485095977783, "kernel": 49.41236114501953}, "v_proj": {"bias": 0.3689476251602173, "kernel": 49.4439697265625}}, "feed_forward": {"intermediate_dense": {"bias": 1.7703373432159424, "kernel": 99.78363800048828}, "output_dense": {"bias": 0.5515254139900208, "kernel": 95.21337127685547}}, "final_layer_norm": {"bias": 2.0224528312683105, "scale": 20.472827911376953}, "layer_norm": {"bias": 1.8376377820968628, "scale": 23.36989974975586}}, "14": {"attention": {"k_proj": {"bias": 0.15231014788150787, "kernel": 49.75981140136719}, "out_proj": {"bias": 1.2114605903625488, "kernel": 47.699607849121094}, "q_proj": {"bias": 2.3986077308654785, "kernel": 49.81935119628906}, "v_proj": {"bias": 0.36967846751213074, "kernel": 47.28139877319336}}, "feed_forward": {"intermediate_dense": {"bias": 1.807011365890503, "kernel": 100.40362548828125}, "output_dense": {"bias": 0.56741863489151, "kernel": 96.57898712158203}}, "final_layer_norm": {"bias": 2.15641450881958, "scale": 20.617422103881836}, "layer_norm": {"bias": 1.9674854278564453, "scale": 23.542499542236328}}, "15": {"attention": {"k_proj": {"bias": 0.0813792496919632, "kernel": 49.88924789428711}, "out_proj": {"bias": 1.2610225677490234, "kernel": 48.30310821533203}, "q_proj": {"bias": 2.5425076484680176, "kernel": 49.9578742980957}, "v_proj": {"bias": 
0.40175360441207886, "kernel": 47.95265197753906}}, "feed_forward": {"intermediate_dense": {"bias": 1.8148760795593262, "kernel": 100.21389770507812}, "output_dense": {"bias": 0.7157906889915466, "kernel": 97.26182556152344}}, "final_layer_norm": {"bias": 2.079712152481079, "scale": 20.710189819335938}, "layer_norm": {"bias": 2.2167983055114746, "scale": 23.696781158447266}}, "16": {"attention": {"k_proj": {"bias": 0.04424087330698967, "kernel": 49.796321868896484}, "out_proj": {"bias": 1.195256233215332, "kernel": 47.77092742919922}, "q_proj": {"bias": 2.625056743621826, "kernel": 49.680885314941406}, "v_proj": {"bias": 0.3579425811767578, "kernel": 47.45350646972656}}, "feed_forward": {"intermediate_dense": {"bias": 1.8125271797180176, "kernel": 100.86848449707031}, "output_dense": {"bias": 0.7396664619445801, "kernel": 98.12709045410156}}, "final_layer_norm": {"bias": 2.1527061462402344, "scale": 21.196475982666016}, "layer_norm": {"bias": 2.1496822834014893, "scale": 22.603479385375977}}, "17": {"attention": {"k_proj": {"bias": 0.03626137971878052, "kernel": 49.99818420410156}, "out_proj": {"bias": 1.137595295906067, "kernel": 47.0876350402832}, "q_proj": {"bias": 2.699965000152588, "kernel": 50.09791564941406}, "v_proj": {"bias": 0.3951329290866852, "kernel": 46.760047912597656}}, "feed_forward": {"intermediate_dense": {"bias": 1.8239432573318481, "kernel": 101.9515151977539}, "output_dense": {"bias": 0.7553638815879822, "kernel": 98.52111053466797}}, "final_layer_norm": {"bias": 2.240124464035034, "scale": 21.756210327148438}, "layer_norm": {"bias": 2.060270309448242, "scale": 22.153118133544922}}, "18": {"attention": {"k_proj": {"bias": 0.072176992893219, "kernel": 50.285491943359375}, "out_proj": {"bias": 1.2414965629577637, "kernel": 48.11112976074219}, "q_proj": {"bias": 2.5908234119415283, "kernel": 50.671958923339844}, "v_proj": {"bias": 0.4236592650413513, "kernel": 47.639381408691406}}, "feed_forward": {"intermediate_dense": {"bias": 1.8664506673812866, "kernel": 102.25341033935547}, "output_dense": {"bias": 0.870419979095459, "kernel": 100.16462707519531}}, "final_layer_norm": {"bias": 2.345587730407715, "scale": 21.717586517333984}, "layer_norm": {"bias": 2.243528366088867, "scale": 23.879430770874023}}, "19": {"attention": {"k_proj": {"bias": 0.03598909452557564, "kernel": 49.554298400878906}, "out_proj": {"bias": 1.2165489196777344, "kernel": 47.99784851074219}, "q_proj": {"bias": 2.8670082092285156, "kernel": 49.98486328125}, "v_proj": {"bias": 0.3872653841972351, "kernel": 47.243560791015625}}, "feed_forward": {"intermediate_dense": {"bias": 1.9215714931488037, "kernel": 102.83073425292969}, "output_dense": {"bias": 0.9350616335868835, "kernel": 101.06717681884766}}, "final_layer_norm": {"bias": 2.303997755050659, "scale": 22.077442169189453}, "layer_norm": {"bias": 2.1655802726745605, "scale": 23.080127716064453}}, "2": {"attention": {"k_proj": {"bias": 0.04701065644621849, "kernel": 46.15753173828125}, "out_proj": {"bias": 1.2106332778930664, "kernel": 43.85963439941406}, "q_proj": {"bias": 3.047680377960205, "kernel": 45.92299270629883}, "v_proj": {"bias": 0.3092108368873596, "kernel": 43.8585090637207}}, "feed_forward": {"intermediate_dense": {"bias": 1.6184449195861816, "kernel": 98.29766845703125}, "output_dense": {"bias": 0.690834641456604, "kernel": 87.27284240722656}}, "final_layer_norm": {"bias": 1.4512405395507812, "scale": 20.990041732788086}, "layer_norm": {"bias": 1.6637616157531738, "scale": 22.045211791992188}}, "20": {"attention": {"k_proj": {"bias": 
0.03447665646672249, "kernel": 49.52602005004883}, "out_proj": {"bias": 1.2453045845031738, "kernel": 47.38385009765625}, "q_proj": {"bias": 2.780540943145752, "kernel": 50.31262969970703}, "v_proj": {"bias": 0.3595719635486603, "kernel": 46.284217834472656}}, "feed_forward": {"intermediate_dense": {"bias": 1.9221566915512085, "kernel": 104.09403991699219}, "output_dense": {"bias": 1.050825834274292, "kernel": 101.69520568847656}}, "final_layer_norm": {"bias": 2.328557014465332, "scale": 23.021968841552734}, "layer_norm": {"bias": 2.140803813934326, "scale": 23.223400115966797}}, "21": {"attention": {"k_proj": {"bias": 0.04754549637436867, "kernel": 49.96974182128906}, "out_proj": {"bias": 1.281482458114624, "kernel": 47.423336029052734}, "q_proj": {"bias": 2.7268178462982178, "kernel": 50.81268310546875}, "v_proj": {"bias": 0.4139178991317749, "kernel": 46.53073501586914}}, "feed_forward": {"intermediate_dense": {"bias": 1.964850902557373, "kernel": 104.28791046142578}, "output_dense": {"bias": 1.121903657913208, "kernel": 102.01963806152344}}, "final_layer_norm": {"bias": 2.3586978912353516, "scale": 22.667612075805664}, "layer_norm": {"bias": 2.212503433227539, "scale": 23.49795913696289}}, "22": {"attention": {"k_proj": {"bias": 0.03752344101667404, "kernel": 50.35940933227539}, "out_proj": {"bias": 1.199880838394165, "kernel": 46.882537841796875}, "q_proj": {"bias": 2.8062198162078857, "kernel": 50.742584228515625}, "v_proj": {"bias": 0.367045521736145, "kernel": 46.751190185546875}}, "feed_forward": {"intermediate_dense": {"bias": 1.8952704668045044, "kernel": 104.68394470214844}, "output_dense": {"bias": 1.1314277648925781, "kernel": 101.29869079589844}}, "final_layer_norm": {"bias": 2.2412726879119873, "scale": 22.188982009887695}, "layer_norm": {"bias": 2.207037925720215, "scale": 22.510005950927734}}, "23": {"attention": {"k_proj": {"bias": 0.12592847645282745, "kernel": 51.452911376953125}, "out_proj": {"bias": 1.3290315866470337, "kernel": 47.882537841796875}, "q_proj": {"bias": 2.644528865814209, "kernel": 51.56488037109375}, "v_proj": {"bias": 0.5197697877883911, "kernel": 48.523468017578125}}, "feed_forward": {"intermediate_dense": {"bias": 1.8733866214752197, "kernel": 104.4806900024414}, "output_dense": {"bias": 1.1092435121536255, "kernel": 102.0977783203125}}, "final_layer_norm": {"bias": 2.4936025142669678, "scale": 22.144878387451172}, "layer_norm": {"bias": 2.6934688091278076, "scale": 23.723365783691406}}, "24": {"attention": {"k_proj": {"bias": 0.06529681384563446, "kernel": 49.950862884521484}, "out_proj": {"bias": 1.384519338607788, "kernel": 49.86079406738281}, "q_proj": {"bias": 2.799373149871826, "kernel": 49.94599914550781}, "v_proj": {"bias": 0.4739426374435425, "kernel": 49.93796157836914}}, "feed_forward": {"intermediate_dense": {"bias": 1.9900020360946655, "kernel": 103.95246887207031}, "output_dense": {"bias": 1.1468805074691772, "kernel": 104.9716796875}}, "final_layer_norm": {"bias": 2.5962982177734375, "scale": 22.201526641845703}, "layer_norm": {"bias": 2.4211277961730957, "scale": 23.269819259643555}}, "25": {"attention": {"k_proj": {"bias": 0.057913169264793396, "kernel": 50.48648452758789}, "out_proj": {"bias": 1.1956175565719604, "kernel": 47.771400451660156}, "q_proj": {"bias": 2.876575469970703, "kernel": 50.27440643310547}, "v_proj": {"bias": 0.5550850629806519, "kernel": 48.308677673339844}}, "feed_forward": {"intermediate_dense": {"bias": 1.889784812927246, "kernel": 104.2199478149414}, "output_dense": {"bias": 1.02622652053833, "kernel": 
104.90184020996094}}, "final_layer_norm": {"bias": 2.300818681716919, "scale": 22.736614227294922}, "layer_norm": {"bias": 2.575133800506592, "scale": 22.408491134643555}}, "26": {"attention": {"k_proj": {"bias": 0.08028850704431534, "kernel": 50.69630432128906}, "out_proj": {"bias": 1.132831335067749, "kernel": 48.54658508300781}, "q_proj": {"bias": 2.836604595184326, "kernel": 50.46205520629883}, "v_proj": {"bias": 0.490480899810791, "kernel": 49.15398406982422}}, "feed_forward": {"intermediate_dense": {"bias": 1.9836318492889404, "kernel": 103.65016174316406}, "output_dense": {"bias": 0.986747145652771, "kernel": 102.07159423828125}}, "final_layer_norm": {"bias": 1.9350225925445557, "scale": 21.59209442138672}, "layer_norm": {"bias": 2.4815142154693604, "scale": 22.84943389892578}}, "27": {"attention": {"k_proj": {"bias": 0.3756049871444702, "kernel": 51.36669158935547}, "out_proj": {"bias": 1.362156867980957, "kernel": 49.87183380126953}, "q_proj": {"bias": 2.6157610416412354, "kernel": 51.22962951660156}, "v_proj": {"bias": 0.570256233215332, "kernel": 50.32680130004883}}, "feed_forward": {"intermediate_dense": {"bias": 2.144805431365967, "kernel": 101.92828369140625}, "output_dense": {"bias": 0.8671581745147705, "kernel": 101.75527954101562}}, "final_layer_norm": {"bias": 2.2173526287078857, "scale": 20.858638763427734}, "layer_norm": {"bias": 2.5520198345184326, "scale": 23.544979095458984}}, "28": {"attention": {"k_proj": {"bias": 0.41085609793663025, "kernel": 52.29168701171875}, "out_proj": {"bias": 1.3885066509246826, "kernel": 50.63475799560547}, "q_proj": {"bias": 2.7657432556152344, "kernel": 51.936649322509766}, "v_proj": {"bias": 0.4593738615512848, "kernel": 50.96391296386719}}, "feed_forward": {"intermediate_dense": {"bias": 2.091826915740967, "kernel": 101.93879699707031}, "output_dense": {"bias": 0.7690470218658447, "kernel": 103.94679260253906}}, "final_layer_norm": {"bias": 2.127680778503418, "scale": 21.179397583007812}, "layer_norm": {"bias": 2.052739143371582, "scale": 24.402130126953125}}, "29": {"attention": {"k_proj": {"bias": 0.07771217077970505, "kernel": 48.74781036376953}, "out_proj": {"bias": 1.3660151958465576, "kernel": 53.145225524902344}, "q_proj": {"bias": 2.740656852722168, "kernel": 48.56964111328125}, "v_proj": {"bias": 0.4171416759490967, "kernel": 53.050506591796875}}, "feed_forward": {"intermediate_dense": {"bias": 2.091569185256958, "kernel": 102.61178588867188}, "output_dense": {"bias": 0.8723453879356384, "kernel": 108.2125244140625}}, "final_layer_norm": {"bias": 2.3656370639801025, "scale": 22.315515518188477}, "layer_norm": {"bias": 2.1508841514587402, "scale": 25.375347137451172}}, "3": {"attention": {"k_proj": {"bias": 0.1246776133775711, "kernel": 50.125587463378906}, "out_proj": {"bias": 1.3616056442260742, "kernel": 46.49916076660156}, "q_proj": {"bias": 2.720198392868042, "kernel": 50.353397369384766}, "v_proj": {"bias": 0.2983784079551697, "kernel": 46.898887634277344}}, "feed_forward": {"intermediate_dense": {"bias": 1.633817434310913, "kernel": 99.93789672851562}, "output_dense": {"bias": 0.651297926902771, "kernel": 90.12864685058594}}, "final_layer_norm": {"bias": 1.7128074169158936, "scale": 21.085018157958984}, "layer_norm": {"bias": 1.8274308443069458, "scale": 23.583423614501953}}, "30": {"attention": {"k_proj": {"bias": 0.2584836781024933, "kernel": 50.669368743896484}, "out_proj": {"bias": 1.160850167274475, "kernel": 49.42723846435547}, "q_proj": {"bias": 2.80226469039917, "kernel": 50.75163650512695}, "v_proj": {"bias": 
0.48041021823883057, "kernel": 49.77007293701172}}, "feed_forward": {"intermediate_dense": {"bias": 2.025425910949707, "kernel": 103.11666870117188}, "output_dense": {"bias": 0.824256420135498, "kernel": 107.20564270019531}}, "final_layer_norm": {"bias": 2.1880717277526855, "scale": 23.454082489013672}, "layer_norm": {"bias": 2.2975010871887207, "scale": 25.109004974365234}}, "31": {"attention": {"k_proj": {"bias": 0.3566759526729584, "kernel": 49.19684600830078}, "out_proj": {"bias": 1.0874900817871094, "kernel": 50.291221618652344}, "q_proj": {"bias": 2.583777904510498, "kernel": 49.29847717285156}, "v_proj": {"bias": 0.5268669724464417, "kernel": 50.42063522338867}}, "feed_forward": {"intermediate_dense": {"bias": 2.105203628540039, "kernel": 101.79988098144531}, "output_dense": {"bias": 1.0006189346313477, "kernel": 104.59204864501953}}, "final_layer_norm": {"bias": 2.082390308380127, "scale": 23.346904754638672}, "layer_norm": {"bias": 2.2939162254333496, "scale": 24.885295867919922}}, "32": {"attention": {"k_proj": {"bias": 0.209731787443161, "kernel": 48.034210205078125}, "out_proj": {"bias": 1.094763994216919, "kernel": 49.47209548950195}, "q_proj": {"bias": 2.8436858654022217, "kernel": 48.01784896850586}, "v_proj": {"bias": 0.3943668007850647, "kernel": 49.758872985839844}}, "feed_forward": {"intermediate_dense": {"bias": 2.0349297523498535, "kernel": 100.66475677490234}, "output_dense": {"bias": 1.0609936714172363, "kernel": 103.93878173828125}}, "final_layer_norm": {"bias": 2.040682792663574, "scale": 23.78135871887207}, "layer_norm": {"bias": 2.246428966522217, "scale": 25.143821716308594}}, "33": {"attention": {"k_proj": {"bias": 0.2134949266910553, "kernel": 47.9801025390625}, "out_proj": {"bias": 1.1317718029022217, "kernel": 49.322593688964844}, "q_proj": {"bias": 2.985382080078125, "kernel": 47.97284698486328}, "v_proj": {"bias": 0.42412418127059937, "kernel": 49.58897018432617}}, "feed_forward": {"intermediate_dense": {"bias": 2.043097496032715, "kernel": 99.04029846191406}, "output_dense": {"bias": 1.034851312637329, "kernel": 102.71003723144531}}, "final_layer_norm": {"bias": 1.9527044296264648, "scale": 23.548221588134766}, "layer_norm": {"bias": 2.4380321502685547, "scale": 25.394405364990234}}, "34": {"attention": {"k_proj": {"bias": 0.2292483150959015, "kernel": 47.18817901611328}, "out_proj": {"bias": 1.3792448043823242, "kernel": 50.80772399902344}, "q_proj": {"bias": 2.868629217147827, "kernel": 47.24061965942383}, "v_proj": {"bias": 0.3926468789577484, "kernel": 50.74346160888672}}, "feed_forward": {"intermediate_dense": {"bias": 2.124262809753418, "kernel": 97.87345886230469}, "output_dense": {"bias": 0.9663103818893433, "kernel": 102.0224609375}}, "final_layer_norm": {"bias": 1.8948910236358643, "scale": 23.202611923217773}, "layer_norm": {"bias": 2.5170724391937256, "scale": 25.778274536132812}}, "35": {"attention": {"k_proj": {"bias": 0.35923632979393005, "kernel": 48.91413879394531}, "out_proj": {"bias": 1.2987806797027588, "kernel": 49.6588134765625}, "q_proj": {"bias": 2.6141324043273926, "kernel": 49.24787139892578}, "v_proj": {"bias": 0.4785917401313782, "kernel": 49.48438262939453}}, "feed_forward": {"intermediate_dense": {"bias": 2.2050743103027344, "kernel": 96.481201171875}, "output_dense": {"bias": 0.8614839315414429, "kernel": 100.76122283935547}}, "final_layer_norm": {"bias": 1.9744343757629395, "scale": 23.32772445678711}, "layer_norm": {"bias": 2.2853195667266846, "scale": 26.26994514465332}}, "36": {"attention": {"k_proj": {"bias": 
0.19287711381912231, "kernel": 46.228477478027344}, "out_proj": {"bias": 1.338456153869629, "kernel": 50.999366760253906}, "q_proj": {"bias": 2.700589418411255, "kernel": 46.217376708984375}, "v_proj": {"bias": 0.362974613904953, "kernel": 51.18370819091797}}, "feed_forward": {"intermediate_dense": {"bias": 2.079521656036377, "kernel": 95.5786361694336}, "output_dense": {"bias": 0.895334005355835, "kernel": 100.45399475097656}}, "final_layer_norm": {"bias": 1.6194080114364624, "scale": 23.851945877075195}, "layer_norm": {"bias": 2.0079894065856934, "scale": 25.780502319335938}}, "37": {"attention": {"k_proj": {"bias": 0.5275993347167969, "kernel": 45.26773452758789}, "out_proj": {"bias": 1.5979468822479248, "kernel": 50.983360290527344}, "q_proj": {"bias": 2.3945603370666504, "kernel": 45.33816146850586}, "v_proj": {"bias": 0.3589479327201843, "kernel": 50.853118896484375}}, "feed_forward": {"intermediate_dense": {"bias": 1.975203037261963, "kernel": 94.83090209960938}, "output_dense": {"bias": 0.9040964841842651, "kernel": 100.22051239013672}}, "final_layer_norm": {"bias": 1.4460301399230957, "scale": 24.251632690429688}, "layer_norm": {"bias": 1.977910041809082, "scale": 25.81577491760254}}, "38": {"attention": {"k_proj": {"bias": 0.6142090559005737, "kernel": 43.45698547363281}, "out_proj": {"bias": 1.2978699207305908, "kernel": 50.466609954833984}, "q_proj": {"bias": 2.3290982246398926, "kernel": 43.46769714355469}, "v_proj": {"bias": 0.4176805019378662, "kernel": 50.338287353515625}}, "feed_forward": {"intermediate_dense": {"bias": 1.9197394847869873, "kernel": 92.87792205810547}, "output_dense": {"bias": 0.8916142582893372, "kernel": 98.47053527832031}}, "final_layer_norm": {"bias": 1.4941749572753906, "scale": 24.96930503845215}, "layer_norm": {"bias": 2.1559910774230957, "scale": 26.533889770507812}}, "39": {"attention": {"k_proj": {"bias": 0.643582820892334, "kernel": 43.23234939575195}, "out_proj": {"bias": 1.5914864540100098, "kernel": 50.338294982910156}, "q_proj": {"bias": 2.1118969917297363, "kernel": 43.619422912597656}, "v_proj": {"bias": 0.38760751485824585, "kernel": 50.01036834716797}}, "feed_forward": {"intermediate_dense": {"bias": 1.9129672050476074, "kernel": 91.19747924804688}, "output_dense": {"bias": 0.9707897901535034, "kernel": 98.85185241699219}}, "final_layer_norm": {"bias": 1.6381645202636719, "scale": 25.602428436279297}, "layer_norm": {"bias": 2.1341910362243652, "scale": 27.172517776489258}}, "4": {"attention": {"k_proj": {"bias": 0.1383555829524994, "kernel": 52.68865966796875}, "out_proj": {"bias": 1.542163372039795, "kernel": 47.90272521972656}, "q_proj": {"bias": 2.522325277328491, "kernel": 52.87029266357422}, "v_proj": {"bias": 0.34495460987091064, "kernel": 48.26051712036133}}, "feed_forward": {"intermediate_dense": {"bias": 1.6215894222259521, "kernel": 99.52203369140625}, "output_dense": {"bias": 0.8154100179672241, "kernel": 91.35391235351562}}, "final_layer_norm": {"bias": 1.795678734779358, "scale": 20.617877960205078}, "layer_norm": {"bias": 1.9197185039520264, "scale": 23.964736938476562}}, "40": {"attention": {"k_proj": {"bias": 0.5853625535964966, "kernel": 42.58599853515625}, "out_proj": {"bias": 1.5349242687225342, "kernel": 48.988765716552734}, "q_proj": {"bias": 2.0470433235168457, "kernel": 43.351776123046875}, "v_proj": {"bias": 0.4405587315559387, "kernel": 48.566322326660156}}, "feed_forward": {"intermediate_dense": {"bias": 1.7741196155548096, "kernel": 89.46418762207031}, "output_dense": {"bias": 1.022929072380066, "kernel": 
96.10987854003906}}, "final_layer_norm": {"bias": 1.8011051416397095, "scale": 24.87212371826172}, "layer_norm": {"bias": 2.0766186714172363, "scale": 26.71459197998047}}, "41": {"attention": {"k_proj": {"bias": 1.6699955463409424, "kernel": 39.9322509765625}, "out_proj": {"bias": 1.2970274686813354, "kernel": 50.54954528808594}, "q_proj": {"bias": 1.7257652282714844, "kernel": 40.694183349609375}, "v_proj": {"bias": 0.39571413397789, "kernel": 49.50176239013672}}, "feed_forward": {"intermediate_dense": {"bias": 1.913979411125183, "kernel": 86.25863647460938}, "output_dense": {"bias": 1.0471625328063965, "kernel": 95.15855407714844}}, "final_layer_norm": {"bias": 2.2990684509277344, "scale": 28.325096130371094}, "layer_norm": {"bias": 2.106232166290283, "scale": 28.505508422851562}}, "42": {"attention": {"k_proj": {"bias": 0.7960388660430908, "kernel": 36.71638488769531}, "out_proj": {"bias": 1.33872652053833, "kernel": 44.78422927856445}, "q_proj": {"bias": 1.5474038124084473, "kernel": 38.06586837768555}, "v_proj": {"bias": 0.5880352258682251, "kernel": 43.13159942626953}}, "feed_forward": {"intermediate_dense": {"bias": 1.6526720523834229, "kernel": 85.25054931640625}, "output_dense": {"bias": 1.0995839834213257, "kernel": 93.36160278320312}}, "final_layer_norm": {"bias": 2.0212883949279785, "scale": 29.624130249023438}, "layer_norm": {"bias": 1.571912407875061, "scale": 27.37674331665039}}, "43": {"attention": {"k_proj": {"bias": 1.2095569372177124, "kernel": 33.2391242980957}, "out_proj": {"bias": 1.3312859535217285, "kernel": 41.17816162109375}, "q_proj": {"bias": 1.3585820198059082, "kernel": 34.052093505859375}, "v_proj": {"bias": 0.517215371131897, "kernel": 39.074363708496094}}, "feed_forward": {"intermediate_dense": {"bias": 1.686711311340332, "kernel": 84.47602844238281}, "output_dense": {"bias": 0.8660041093826294, "kernel": 91.29061126708984}}, "final_layer_norm": {"bias": 1.947145938873291, "scale": 31.841339111328125}, "layer_norm": {"bias": 1.6917033195495605, "scale": 25.53169059753418}}, "44": {"attention": {"k_proj": {"bias": 2.490297317504883, "kernel": 33.829734802246094}, "out_proj": {"bias": 1.0940178632736206, "kernel": 44.90290832519531}, "q_proj": {"bias": 1.290268898010254, "kernel": 34.20195388793945}, "v_proj": {"bias": 0.3792087137699127, "kernel": 43.996761322021484}}, "feed_forward": {"intermediate_dense": {"bias": 1.7657651901245117, "kernel": 83.43031311035156}, "output_dense": {"bias": 0.8124760985374451, "kernel": 88.94329833984375}}, "final_layer_norm": {"bias": 1.9339977502822876, "scale": 34.01377868652344}, "layer_norm": {"bias": 1.5856916904449463, "scale": 25.54978370666504}}, "45": {"attention": {"k_proj": {"bias": 2.0486204624176025, "kernel": 33.6713981628418}, "out_proj": {"bias": 0.9801490902900696, "kernel": 48.500160217285156}, "q_proj": {"bias": 1.370434284210205, "kernel": 33.85618591308594}, "v_proj": {"bias": 0.4300559461116791, "kernel": 48.65924835205078}}, "feed_forward": {"intermediate_dense": {"bias": 1.8838140964508057, "kernel": 80.09923553466797}, "output_dense": {"bias": 0.9479005932807922, "kernel": 84.33807373046875}}, "final_layer_norm": {"bias": 1.6806273460388184, "scale": 32.72142791748047}, "layer_norm": {"bias": 1.5162461996078491, "scale": 24.066335678100586}}, "46": {"attention": {"k_proj": {"bias": 1.5381855964660645, "kernel": 34.851234436035156}, "out_proj": {"bias": 0.7452247142791748, "kernel": 50.928497314453125}, "q_proj": {"bias": 1.5366780757904053, "kernel": 34.97130584716797}, "v_proj": {"bias": 
0.3714127540588379, "kernel": 51.681663513183594}}, "feed_forward": {"intermediate_dense": {"bias": 1.9417197704315186, "kernel": 74.44387817382812}, "output_dense": {"bias": 1.1015146970748901, "kernel": 74.64437866210938}}, "final_layer_norm": {"bias": 1.670118808746338, "scale": 28.23578453063965}, "layer_norm": {"bias": 1.3335180282592773, "scale": 22.981168746948242}}, "47": {"attention": {"k_proj": {"bias": 0.26057976484298706, "kernel": 37.12379837036133}, "out_proj": {"bias": 0.6302802562713623, "kernel": 45.20001983642578}, "q_proj": {"bias": 1.6586006879806519, "kernel": 37.767982482910156}, "v_proj": {"bias": 0.34656229615211487, "kernel": 46.186058044433594}}, "feed_forward": {"intermediate_dense": {"bias": 1.9915863275527954, "kernel": 71.79125213623047}, "output_dense": {"bias": 0.6057736873626709, "kernel": 68.15716552734375}}, "final_layer_norm": {"bias": 1.5099494457244873, "scale": 23.07884407043457}, "layer_norm": {"bias": 1.059046745300293, "scale": 20.232519149780273}}, "5": {"attention": {"k_proj": {"bias": 0.03299503028392792, "kernel": 48.02180480957031}, "out_proj": {"bias": 1.5268709659576416, "kernel": 49.123374938964844}, "q_proj": {"bias": 2.615450859069824, "kernel": 48.16506576538086}, "v_proj": {"bias": 0.3084052801132202, "kernel": 49.91838836669922}}, "feed_forward": {"intermediate_dense": {"bias": 1.5458447933197021, "kernel": 99.63191223144531}, "output_dense": {"bias": 0.8447100520133972, "kernel": 90.65230560302734}}, "final_layer_norm": {"bias": 2.073741912841797, "scale": 20.830238342285156}, "layer_norm": {"bias": 1.9506666660308838, "scale": 23.363054275512695}}, "6": {"attention": {"k_proj": {"bias": 0.2025420367717743, "kernel": 49.65654754638672}, "out_proj": {"bias": 1.5175342559814453, "kernel": 48.44850158691406}, "q_proj": {"bias": 2.6657018661499023, "kernel": 50.13703918457031}, "v_proj": {"bias": 0.31043344736099243, "kernel": 48.978675842285156}}, "feed_forward": {"intermediate_dense": {"bias": 1.5249176025390625, "kernel": 98.72537994384766}, "output_dense": {"bias": 0.6966495513916016, "kernel": 90.24320220947266}}, "final_layer_norm": {"bias": 2.373417854309082, "scale": 20.306549072265625}, "layer_norm": {"bias": 1.953676700592041, "scale": 23.746742248535156}}, "7": {"attention": {"k_proj": {"bias": 0.19703075289726257, "kernel": 49.45228576660156}, "out_proj": {"bias": 1.3325653076171875, "kernel": 48.69895935058594}, "q_proj": {"bias": 2.4418447017669678, "kernel": 49.841575622558594}, "v_proj": {"bias": 0.39479926228523254, "kernel": 48.66286849975586}}, "feed_forward": {"intermediate_dense": {"bias": 1.5313714742660522, "kernel": 98.47097778320312}, "output_dense": {"bias": 0.537929117679596, "kernel": 89.98369598388672}}, "final_layer_norm": {"bias": 2.213186264038086, "scale": 20.54340171813965}, "layer_norm": {"bias": 1.8572309017181396, "scale": 22.469369888305664}}, "8": {"attention": {"k_proj": {"bias": 0.1752677857875824, "kernel": 48.950439453125}, "out_proj": {"bias": 1.1584160327911377, "kernel": 49.24664306640625}, "q_proj": {"bias": 2.4162092208862305, "kernel": 48.71965789794922}, "v_proj": {"bias": 0.3242269456386566, "kernel": 49.42983627319336}}, "feed_forward": {"intermediate_dense": {"bias": 1.585384726524353, "kernel": 98.07679748535156}, "output_dense": {"bias": 0.4949800968170166, "kernel": 89.4083251953125}}, "final_layer_norm": {"bias": 2.1687498092651367, "scale": 20.330463409423828}, "layer_norm": {"bias": 1.7944811582565308, "scale": 22.929916381835938}}, "9": {"attention": {"k_proj": {"bias": 
0.20956851541996002, "kernel": 49.57889938354492}, "out_proj": {"bias": 1.361301302909851, "kernel": 50.04251480102539}, "q_proj": {"bias": 2.375364065170288, "kernel": 49.73973846435547}, "v_proj": {"bias": 0.3334028422832489, "kernel": 50.451393127441406}}, "feed_forward": {"intermediate_dense": {"bias": 1.667464256286621, "kernel": 96.68904113769531}, "output_dense": {"bias": 0.635001540184021, "kernel": 89.95124816894531}}, "final_layer_norm": {"bias": 2.06015682220459, "scale": 19.616397857666016}, "layer_norm": {"bias": 1.8842439651489258, "scale": 24.3007869720459}}}, "pos_conv_embed": {"conv": {"bias": 5.55006217956543, "weight_g": 8.820418357849121, "weight_v": 84.62606811523438}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 2.0290679931640625, "kernel": 20.55536460876465}, "layer_norm": {"bias": 4.550922393798828, "scale": 16.167570114135742}}, "1": {"conv": {"bias": 1.7790228128433228, "kernel": 51.24136734008789}, "layer_norm": {"bias": 5.962646961212158, "scale": 23.268157958984375}}, "2": {"conv": {"bias": 1.140576720237732, "kernel": 46.50312042236328}, "layer_norm": {"bias": 4.176670551300049, "scale": 20.370853424072266}}, "3": {"conv": {"bias": 0.6725863218307495, "kernel": 44.397525787353516}, "layer_norm": {"bias": 3.888174533843994, "scale": 17.53795051574707}}, "4": {"conv": {"bias": 0.6373162269592285, "kernel": 41.314056396484375}, "layer_norm": {"bias": 2.385471820831299, "scale": 16.34571647644043}}, "5": {"conv": {"bias": 0.5147221684455872, "kernel": 37.479759216308594}, "layer_norm": {"bias": 2.020900011062622, "scale": 17.064470291137695}}, "6": {"conv": {"bias": 0.4947893023490906, "kernel": 40.64780044555664}, "layer_norm": {"bias": 0.5876954793930054, "scale": 19.058603286743164}}}}, "feature_projection": {"layer_norm": {"bias": 6.31657600402832, "scale": 16.55396270751953}, "projection": {"bias": 1.6528528928756714, "kernel": 34.70302963256836}}, "masked_spec_embed": 11.914372444152832}}, "train/learning_rate": 1.9923363652196713e-05, "train/loss": 0.7011880874633789, "train/param_norm": 1186.244384765625, "_runtime": 11972, "_timestamp": 1660225044, "_step": 4900} \ No newline at end of file diff --git a/wandb/run-20220811_101752-mzjvp6ho/logs/debug-internal.log b/wandb/run-20220811_101752-mzjvp6ho/logs/debug-internal.log new file mode 100644 index 0000000000000000000000000000000000000000..3b32ca6c0a1795e92a5b2b4cd7c304753a83fd37 --- /dev/null +++ b/wandb/run-20220811_101752-mzjvp6ho/logs/debug-internal.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a4dbdcd97fce536189b76955b89f1c8d89a257f41b65af8725fd9b31c3c85fe +size 871534 diff --git a/wandb/run-20220811_101752-mzjvp6ho/logs/debug.log b/wandb/run-20220811_101752-mzjvp6ho/logs/debug.log new file mode 100644 index 0000000000000000000000000000000000000000..09e22995fe65d0269d9d201f71d29429ce129f54 --- /dev/null +++ b/wandb/run-20220811_101752-mzjvp6ho/logs/debug.log @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8aa5aa51ba18ae388c7e80e7807dc48688f60241b1a194d71f5874c6057528b5 +size 2667 diff --git a/wandb/run-20220811_101752-mzjvp6ho/run-mzjvp6ho.wandb b/wandb/run-20220811_101752-mzjvp6ho/run-mzjvp6ho.wandb new file mode 100644 index 0000000000000000000000000000000000000000..6d5eb953ab9ce64b51d96aaa5dd6ea574443822e --- /dev/null +++ b/wandb/run-20220811_101752-mzjvp6ho/run-mzjvp6ho.wandb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a27b96625bf0a9670d38dde15787c631d69ba2791f40f6098124958854c8f5ed +size 4123380
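Note: the per-module values in the wandb summary above (one number per `kernel`, `bias`, `scale`, etc.) read like L2 norms of each parameter leaf, with "train/param_norm" as the corresponding global norm. As a point of reference only, the following is a minimal sketch under that assumption; the helper names `param_norm_tree` and `global_param_norm` are hypothetical and not taken from this repository.

```python
# Hypothetical sketch: per-leaf L2 norms of a Flax/JAX parameter pytree,
# shaped like the nested norm entries in the wandb summary above.
import jax
import jax.numpy as jnp


def param_norm_tree(params):
    """Map each parameter leaf to its L2 norm, preserving the tree structure."""
    return jax.tree_util.tree_map(lambda p: float(jnp.linalg.norm(p)), params)


def global_param_norm(params):
    """Single scalar norm over all leaves (cf. "train/param_norm")."""
    leaves = jax.tree_util.tree_leaves(params)
    return float(jnp.sqrt(sum(jnp.vdot(p, p) for p in leaves)))


if __name__ == "__main__":
    # Toy parameter tree standing in for the real model params.
    params = {
        "attention": {"q_proj": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros(4)}},
        "layer_norm": {"scale": jnp.ones(4), "bias": jnp.zeros(4)},
    }
    print(param_norm_tree(params))   # nested dict of per-leaf norms
    print(global_param_norm(params)) # one scalar over all leaves
```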