""" |
|
Convert a Flax training state to HF Transformers Whisper weights. |
|
""" |

import logging
import os
import sys
from dataclasses import field
from pathlib import Path
from typing import Callable, Optional

import flax
import jax
import jax.numpy as jnp
import optax
from flax import jax_utils, traverse_util
from flax.serialization import from_bytes
from flax.training import train_state
from flax.training.common_utils import shard_prng_key
from huggingface_hub import Repository, create_repo
from transformers import (
    AutoConfig,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
)
from transformers.file_utils import get_full_repo_name
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

from distil_whisper import FlaxWhisperForConditionalGeneration

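
# Initialize the JAX distributed runtime so the script can run on multi-host setups
# (e.g. TPU pods).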
jax.distributed.initialize()

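# The script relies on features from recent transformers and datasets releases, so
# fail fast if the installed versions are too old.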
check_min_version("4.27.0.dev0")

require_version(
    "datasets>=1.18.0",
    "To fix: pip install -r examples/flax/speech-recognition/requirements.txt",
)

logger = logging.getLogger(__name__)


@flax.struct.dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": ("Path to pretrained student model or model identifier from huggingface.co/models")}
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": ("Where to store the pretrained models downloaded from huggingface.co")},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": ("Whether to use one of the fast tokenizers (backed by the tokenizers library) or not.")},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": ("The specific model version to use (can be a branch name, tag name or commit id).")},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `transformers-cli login`"
                " (necessary to use this script with private models)."
            )
        },
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": (
                "Floating-point format in which the model weights should be initialized"
                " and trained. Choose one of `[float32, float16, bfloat16]`."
            )
        },
    )
    load_with_scan_weights: bool = field(
        default=False,
        metadata={
            "help": "Whether the pre-trained checkpoint has its weights stored in scan format. Set to True for scanned "
            "weights, defaults to False for non-scan (unrolled) weights."
        },
    )
    use_scan: bool = field(
        default=True,
        metadata={"help": ("Whether or not to use `scan_with_axes` over the encoder and decoder blocks.")},
    )


def create_learning_rate_fn(
    num_train_steps: int, lr_scheduler_type: str, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear decay learning rate function."""
    lr_scheduler_types = ("linear", "constant_with_warmup")

    if lr_scheduler_type not in lr_scheduler_types:
        raise ValueError(
            f"lr_scheduler_type of type {lr_scheduler_type} not supported, choose from {lr_scheduler_types}."
        )

    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate,
        end_value=0 if lr_scheduler_type == "linear" else learning_rate,
        transition_steps=num_train_steps - num_warmup_steps,
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn
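

# For example, `create_learning_rate_fn(100, "linear", 10, 1e-4)` ramps the learning
# rate linearly from 0 to 1e-4 over the first 10 steps, then decays it linearly back
# to 0 over the remaining 90 steps; with "constant_with_warmup" it instead stays at
# 1e-4 after warmup.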
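

# A training state that, in addition to the standard Flax fields, tracks the dropout
# PRNG key and the maximum gradient norm used for clipping in `apply_gradients`.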
class TrainState(train_state.TrainState):
    dropout_rng: jnp.ndarray
    max_grad_norm: float

    def apply_gradients(self, *, grads, **kwargs):
        """Updates `step`, `params`, `opt_state` and `**kwargs` in return value, clipping the
        gradients by the maximum grad norm.

        Note that internally this function calls `.tx.update()` followed by a call
        to `optax.apply_updates()` to update `params` and `opt_state`.

        Args:
            grads: Gradients that have the same pytree structure as `.params`.
            **kwargs: Additional dataclass attributes that should be `.replace()`-ed.

        Returns:
            An updated instance of `self` with `step` incremented by one, `params`
            and `opt_state` updated by applying `grads`, and additional attributes
            replaced as specified by `kwargs`.
        """

        # Clip the gradients to a maximum global norm of `max_grad_norm`: when the
        # global norm is already below the threshold the scaling factor is 1 and the
        # gradients are left unchanged.
        g_norm = optax.global_norm(grads)
        g_norm = jnp.maximum(self.max_grad_norm, g_norm)
        grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * self.max_grad_norm, grads)

        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
        new_params = optax.apply_updates(self.params, updates)

        return self.replace(
            step=self.step + 1,
            params=new_params,
            opt_state=new_opt_state,
            **kwargs,
        )

    def replicate(self):
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))

    def unreplicate(self):
        return jax_utils.unreplicate(self)


def main():
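    # Parse the command-line arguments (or a single .json config file) into the model
    # and training argument dataclasses.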
    parser = HfArgumentParser(
        (
            ModelArguments,
            Seq2SeqTrainingArguments,
        )
    )

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # parse it to get our arguments.
        model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, training_args = parser.parse_args_into_dataclasses()

    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name,
                token=training_args.hub_token,
            )
        else:
            repo_name = training_args.hub_model_id
        create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
        repo = Repository(
            training_args.output_dir,
            clone_from=repo_name,
            token=training_args.hub_token,
        )

    config = AutoConfig.from_pretrained(
        (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    student_model, student_params = FlaxWhisperForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        dtype=getattr(jnp, model_args.dtype),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        _do_init=False,
        use_scan=model_args.load_with_scan_weights,
    )

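    # When training with scan, convert the unrolled (per-layer) weights loaded above
    # into the stacked format expected by `scan_with_axes`.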
    if model_args.use_scan:
        student_model.enable_scan()
        student_params = student_model.convert_unroll_to_scan(student_params)

    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    total_train_steps = int(training_args.max_steps)

    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        total_train_steps,
        training_args.lr_scheduler_type,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

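    # Mask for weight decay: exclude all bias terms and LayerNorm parameters so that
    # decay is only applied to the remaining weight matrices.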
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)

        layer_norm_candidates = [
            "layer_norm",
            "self_attn_layer_norm",
            "final_layer_norm",
            "encoder_attn_layer_norm",
        ]
        layer_norm_named_params = {
            layer[-2:]
            for layer_norm_name in layer_norm_candidates
            for layer in flat_params.keys()
            if layer_norm_name in "".join(layer).lower()
        }
        flat_mask = {path: path[-1] != "bias" and path[-2:] not in layer_norm_named_params for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)

    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    student_state = TrainState.create(
        apply_fn=student_model.__call__,
        params=student_params,
        tx=adamw,
        dropout_rng=dropout_rng,
        max_grad_norm=training_args.max_grad_norm,
    )

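    # Restore the full training state (parameters, optimizer state and step count)
    # from the serialized `train_state.msgpack` file, if a checkpoint is given.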
    if training_args.resume_from_checkpoint is not None:
        if os.path.isfile(os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")):
            logger.info(
                f"Checkpoint detected, resuming training at {training_args.resume_from_checkpoint}. To avoid "
                "this behavior, omit the resume_from_checkpoint argument."
            )
            with Path(os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")).open("rb") as f:
                student_state = from_bytes(student_state, f.read())
        else:
            logger.warning(
                f"Checkpoint {training_args.resume_from_checkpoint} not detected, converting the pre-trained "
                "weights without restoring a training state. Ensure you pass the path to a folder with a valid "
                "checkpoint for your model."
            )

    cur_step = int(jax.device_get(student_state.step))

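    # Only the first process writes the checkpoint: convert the scanned weights back
    # to the unrolled format, fetch them from device, and save them in Transformers
    # format (optionally pushing to the Hub).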
    if jax.process_index() == 0:
        student_model.disable_scan()
        student_state_params = student_model.convert_scan_to_unroll(student_state.params)
        student_params = jax.device_get(student_state_params)
        student_model.save_pretrained(
            os.path.join(training_args.output_dir, f"checkpoint-{cur_step}"), params=student_params
        )
        if training_args.push_to_hub:
            repo.push_to_hub(
                commit_message=f"Saving weights of step {cur_step}",
                blocking=False,
            )


if __name__ == "__main__":
    main()