# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
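"""Launch AVLM 8B pretraining or finetuning with NeMo-Run.

Example usage (paths below are illustrative placeholders, not shipped assets):

    # Pretrain, initializing the language model from a pretrained checkpoint
    python avlm_nemo_run.py --training_mode pretrain --language_model_from_pretrained /path/to/llm_checkpoint

    # Finetune from an AVLM checkpoint with the language model unfrozen
    python avlm_nemo_run.py --training_mode finetune --checkpoint_path /path/to/avlm_checkpoint --unfreeze_language_model
"""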
import argparse

import nemo_run as run

from nemo.collections import avlm

def configure_recipe(
    nodes: int = 1,
    gpus_per_node: int = 8,
    pretrain=False,
    language_model_from_pretrained=None,
    checkpoint_path=None,
    output_dir=None,
    freeze_modules=None,
):
    """Configure the AVLM 8B pretraining or finetuning recipe."""
    if pretrain:
        recipe = avlm.avlm_8b.pretrain_recipe(
            dir=output_dir,  # Path to store checkpoints
            name="avlm_pretrain",
            num_nodes=nodes,
            num_gpus_per_node=gpus_per_node,
            language_model_from_pretrained=language_model_from_pretrained,
            freeze_modules=freeze_modules,
        )
    else:
        recipe = avlm.avlm_8b.finetune_recipe(
            checkpoint_path=checkpoint_path,
            dir=output_dir,  # Path to store checkpoints
            name="avlm_finetune",
            num_nodes=nodes,
            num_gpus_per_node=gpus_per_node,
            freeze_modules=freeze_modules,
            peft_scheme="none",  # Full finetuning; no PEFT adapters
        )
    # Short run for demonstration purposes; raise these values for a real training job.
    recipe.trainer.max_steps = 20
    recipe.trainer.val_check_interval = 20
    return recipe

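# The recipe returned above is a mutable config: callers can override trainer
# fields before launching. A longer (non-smoke-test) run might do, for example
# (illustrative values, not part of the original script):
#
#   recipe = configure_recipe(pretrain=True, output_dir="/results/avlm")
#   recipe.trainer.max_steps = 10_000
#   recipe.trainer.val_check_interval = 1_000
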
def local_executor_torchrun(nodes: int = 1, devices: int = 8) -> run.LocalExecutor:
    """Create a local executor that launches each task via torchrun."""
    # Environment variables for the launched jobs are configured here, e.g.
    # env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" (illustrative, not set by default).
    env_vars = {}

    executor = run.LocalExecutor(ntasks_per_node=devices, launcher="torchrun", env_vars=env_vars)
    return executor

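# Note: this script only wires up a single-machine torchrun launch. For
# multi-node or cluster jobs, NeMo-Run also provides cluster-aware executors
# (e.g., run.SlurmExecutor); see the NeMo-Run documentation.
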
def run_pretraining(language_model_from_pretrained=None, checkpoint_path=None, output_dir=None, freeze_modules=None):
    """Configure and launch an AVLM pretraining run."""
    recipe = configure_recipe(
        pretrain=True,
        language_model_from_pretrained=language_model_from_pretrained,
        checkpoint_path=checkpoint_path,
        output_dir=output_dir,
        freeze_modules=freeze_modules,
    )
    executor = local_executor_torchrun(nodes=recipe.trainer.num_nodes, devices=recipe.trainer.devices)
    run.run(recipe, executor=executor)

def run_finetuning(checkpoint_path=None, output_dir=None, freeze_modules=None):
    """Configure and launch an AVLM finetuning run."""
    recipe = configure_recipe(
        pretrain=False, checkpoint_path=checkpoint_path, output_dir=output_dir, freeze_modules=freeze_modules
    )
    executor = local_executor_torchrun(nodes=recipe.trainer.num_nodes, devices=recipe.trainer.devices)
    run.run(recipe, executor=executor)

# This condition is necessary for the script to be compatible with Python's multiprocessing module.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Launch AVLM 8B pretraining or finetuning with NeMo-Run.")
    parser.add_argument(
        "--training_mode",
        type=str,
        required=True,
        choices=["pretrain", "finetune"],
        help="Training mode - either 'pretrain' or 'finetune'.",
    )
    parser.add_argument(
        "--language_model_from_pretrained",
        type=str,
        default=None,
        required=False,
        help="Path to a pretrained language model to initialize from (optional, used only in pretrain mode).",
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default=None, required=False, help="Path to a checkpoint to train from (optional)."
    )
    parser.add_argument(
        "--output_dir", type=str, default="./outputs/checkpoints/avlm", help="Path to store checkpoints (optional)."
    )
    parser.add_argument("--unfreeze_language_model", action="store_true", help="Unfreeze the language model (optional).")
    parser.add_argument("--unfreeze_vision_model", action="store_true", help="Unfreeze the vision model (optional).")
    parser.add_argument("--unfreeze_audio_model", action="store_true", help="Unfreeze the audio model (optional).")
    parser.add_argument(
        "--unfreeze_vision_projection", action="store_true", help="Unfreeze the vision projection (optional)."
    )
    parser.add_argument(
        "--unfreeze_audio_projection", action="store_true", help="Unfreeze the audio projection (optional)."
    )
    args = parser.parse_args()
    # All modules are frozen by default; each --unfreeze_* flag marks the
    # corresponding module as trainable before the run is launched via nemo_run.
    freeze_modules = {
        "freeze_language_model": not args.unfreeze_language_model,
        "freeze_vision_model": not args.unfreeze_vision_model,
        "freeze_audio_model": not args.unfreeze_audio_model,
        "freeze_vision_projection": not args.unfreeze_vision_projection,
        "freeze_audio_projection": not args.unfreeze_audio_projection,
    }
    if args.training_mode == "pretrain":
        run_pretraining(
            language_model_from_pretrained=args.language_model_from_pretrained,
            checkpoint_path=args.checkpoint_path,
            output_dir=args.output_dir,
            freeze_modules=freeze_modules,
        )
    elif args.training_mode == "finetune":
        run_finetuning(
            checkpoint_path=args.checkpoint_path,
            output_dir=args.output_dir,
            freeze_modules=freeze_modules,
        )