|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
|
import nemo_run as run |
|
|
|
from nemo.collections import avlm |
|
|
|
|
|
def configure_recipe(
    nodes: int = 1,
    gpus_per_node: int = 8,
    pretrain=False,
    language_model_from_pretrained=None,
    checkpoint_path=None,
    output_dir=None,
    freeze_modules=None,
    *,
    max_steps: int = 20,
    val_check_interval: int = 20,
):
    """Build an AVLM-8B training recipe for either pretraining or finetuning.

    Args:
        nodes: Number of nodes, forwarded to the recipe as ``num_nodes``.
        gpus_per_node: GPUs per node, forwarded as ``num_gpus_per_node``.
        pretrain: If True, build ``avlm.avlm_8b.pretrain_recipe``; otherwise
            build ``avlm.avlm_8b.finetune_recipe`` with ``peft_scheme="none"``.
        language_model_from_pretrained: Forwarded to the pretrain recipe only;
            ignored when ``pretrain`` is False.
        checkpoint_path: Forwarded to the finetune recipe only; ignored when
            ``pretrain`` is True.
        output_dir: Forwarded as the recipe's ``dir``.
        freeze_modules: Forwarded unchanged to whichever recipe is built.
        max_steps: Assigned to ``recipe.trainer.max_steps``. Defaults to 20,
            the value previously hard-coded here (presumably a short smoke run).
        val_check_interval: Assigned to ``recipe.trainer.val_check_interval``.
            Defaults to 20, the value previously hard-coded here.

    Returns:
        The recipe object produced by the selected ``avlm.avlm_8b`` factory,
        with its trainer step settings overridden.
    """
    if pretrain:
        recipe = avlm.avlm_8b.pretrain_recipe(
            dir=output_dir,
            name="avlm_pretrain",
            num_nodes=nodes,
            num_gpus_per_node=gpus_per_node,
            language_model_from_pretrained=language_model_from_pretrained,
            freeze_modules=freeze_modules,
        )
    else:
        recipe = avlm.avlm_8b.finetune_recipe(
            checkpoint_path=checkpoint_path,
            dir=output_dir,
            name="avlm_finetune",
            num_nodes=nodes,
            num_gpus_per_node=gpus_per_node,
            freeze_modules=freeze_modules,
            # NOTE(review): "none" presumably means no PEFT adapters (full
            # finetuning) -- confirm against avlm_8b.finetune_recipe.
            peft_scheme="none",
        )

    recipe.trainer.max_steps = max_steps
    recipe.trainer.val_check_interval = val_check_interval

    return recipe
|
|
|
|
|
def local_executor_torchrun(nodes: int = 1, devices: int = 8) -> run.LocalExecutor:
    """Create a ``run.LocalExecutor`` that starts ``devices`` tasks via torchrun.

    Args:
        nodes: Accepted so callers can pass the recipe's node count, but not
            used by this function; only ``devices`` reaches the executor.
        devices: Passed to the executor as ``ntasks_per_node``.

    Returns:
        A ``run.LocalExecutor`` using the ``"torchrun"`` launcher and an
        empty ``env_vars`` mapping.
    """
    return run.LocalExecutor(
        ntasks_per_node=devices,
        launcher="torchrun",
        env_vars={},  # no extra environment variables are injected
    )
|
|
|
|
|
def run_pretraining(language_model_from_pretrained=None, checkpoint_path=None, output_dir=None, freeze_modules=None):
    """Build the AVLM pretraining recipe and launch it on a local torchrun executor.

    Every argument is forwarded unchanged to ``configure_recipe(pretrain=True, ...)``.
    Within this file, ``configure_recipe`` does not pass ``checkpoint_path``
    to the pretrain recipe.
    """
    pretrain_recipe = configure_recipe(
        pretrain=True,
        language_model_from_pretrained=language_model_from_pretrained,
        checkpoint_path=checkpoint_path,
        output_dir=output_dir,
        freeze_modules=freeze_modules,
    )
    trainer_cfg = pretrain_recipe.trainer
    # Size the executor from the recipe's own trainer settings.
    run.run(
        pretrain_recipe,
        executor=local_executor_torchrun(nodes=trainer_cfg.num_nodes, devices=trainer_cfg.devices),
    )
|
|
|
|
|
def run_finetuning(checkpoint_path=None, output_dir=None, freeze_modules=None):
    """Build the AVLM finetuning recipe and launch it on a local torchrun executor.

    Every argument is forwarded unchanged to ``configure_recipe(pretrain=False, ...)``.
    """
    finetune_recipe = configure_recipe(
        pretrain=False,
        checkpoint_path=checkpoint_path,
        output_dir=output_dir,
        freeze_modules=freeze_modules,
    )
    trainer_cfg = finetune_recipe.trainer
    # Size the executor from the recipe's own trainer settings.
    run.run(
        finetune_recipe,
        executor=local_executor_torchrun(nodes=trainer_cfg.num_nodes, devices=trainer_cfg.devices),
    )
|
|
|
|
|
|
|
if __name__ == "__main__":
    import warnings

    parser = argparse.ArgumentParser(
        description="Launch AVLM-8B pretraining or finetuning locally with torchrun."
    )
    parser.add_argument(
        "--training_mode",
        type=str,
        required=True,
        choices=["pretrain", "finetune"],
        help="Training mode - either 'pretrain' or 'finetune'",
    )
    parser.add_argument(
        "--language_model_from_pretrained",
        type=str,
        default=None,
        required=False,
        help="Path to pretrained language model (optional).",
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default=None, required=False, help="Path to checkpoint (optional)."
    )
    parser.add_argument(
        "--output_dir", type=str, default="./outputs/checkpoints/avlm", help="Path to store checkpoints (optional)."
    )

    # Every module is frozen unless its --unfreeze_<module> flag is given.
    # A single tuple drives both the CLI flags and the freeze_modules dict so
    # the two cannot drift apart.
    freezable_modules = (
        "language_model",
        "vision_model",
        "audio_model",
        "vision_projection",
        "audio_projection",
    )
    for module in freezable_modules:
        parser.add_argument(
            f"--unfreeze_{module}",
            action="store_true",
            help=f"Unfreeze {module.replace('_', ' ')} (optional).",
        )

    args = parser.parse_args()

    freeze_modules = {f"freeze_{module}": not getattr(args, f"unfreeze_{module}") for module in freezable_modules}

    if args.training_mode == "pretrain":
        # configure_recipe does not forward checkpoint_path to the pretrain recipe.
        if args.checkpoint_path is not None:
            warnings.warn("--checkpoint_path is ignored when --training_mode=pretrain.")
        run_pretraining(
            language_model_from_pretrained=args.language_model_from_pretrained,
            checkpoint_path=args.checkpoint_path,
            output_dir=args.output_dir,
            freeze_modules=freeze_modules,
        )
    else:  # argparse `choices` guarantees "finetune" here
        # run_finetuning has no language_model_from_pretrained parameter.
        if args.language_model_from_pretrained is not None:
            warnings.warn("--language_model_from_pretrained is ignored when --training_mode=finetune.")
        run_finetuning(
            checkpoint_path=args.checkpoint_path,
            output_dir=args.output_dir,
            freeze_modules=freeze_modules,
        )
|
|