# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example:
torchrun --nproc_per_node=8 scripts/avlm/avlm_pretrain.py \
    --devices=8 --tp_size=4 --data_type=mock

torchrun --nproc_per_node=8 scripts/avlm/avlm_pretrain.py \
    --devices=8 --tp_size=4 --data_type=energon --data_path='' \
    --language_model_path=/root/.cache/nemo/models/lmsys/vicuna-7b-v1.5
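
An illustrative run with sequence packing and sequence parallelism enabled
(flag values are examples, not tuned recommendations):

torchrun --nproc_per_node=8 scripts/avlm/avlm_pretrain.py \
    --devices=8 --tp_size=4 --sequence_parallel=true --use_packed_sequence=true \
    --data_type=energon --data_path='' \
    --language_model_path=/root/.cache/nemo/models/lmsys/vicuna-7b-v1.5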
"""
import argparse
import torch
from lightning.pytorch.loggers import WandbLogger
from megatron.core.optimizer import OptimizerConfig
from megatron.core.transformer.enums import AttnBackend
from nemo import lightning as nl
from nemo.collections import avlm, llm, vlm
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.speechlm.modules.asr_module import ASRModuleConfig
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
from nemo.utils.exp_manager import TimingCallback
def main(args):
# pylint: disable=C0115,C0116
# Global and micro batch sizes
gbs = args.gbs
mbs = args.mbs
num_workers = args.num_workers
max_steps = args.max_steps
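
    # --sequence_parallel and --use_packed_sequence arrive as the strings
    # "true" / "false" (rather than store_true flags) so launchers can set
    # them explicitly either way; normalize them to booleans here.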
if args.sequence_parallel == "true":
args.sequence_parallel = True
elif args.sequence_parallel == "false":
args.sequence_parallel = False
else:
raise ValueError(f"Invalid sequence parallel value: {args.sequence_parallel}")
if args.use_packed_sequence == "true":
args.use_packed_sequence = True
elif args.use_packed_sequence == "false":
args.use_packed_sequence = False
else:
raise ValueError(f"Invalid use packed sequence value: {args.use_packed_sequence}")
decoder_seq_length = args.seq_length
if args.use_packed_sequence:
decoder_seq_length = int(args.seq_length * 2)
if args.data_type == "energon":
from nemo.collections.avlm.data.energon import AVLMDataModule, AVLMSampleConfig, AVLMTaskEncoder
data_path = args.data_path
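
        # The sample config tells the task encoder how many embedding tokens
        # each modality produces: Whisper-style mel-spectrogram framing for
        # audio, and a 336x336 image split into a 14-pixel patch grid.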
avlm_sample_config = AVLMSampleConfig(
audio_encoder_config={ # whisper audio encoder
"model_type": "whisper",
"window_stride": 0.01,
"sample_rate": 16000,
"fixed_max_audio_length": 29.9999 * 16000,
"encoder_down_sampling": 2,
"num_mel_bins": None,
"patch_size": None,
"time_stride": None,
"frequency_stride": None,
"max_spectrogram_length": None,
},
image_encoder_config={
"model_type": "vit",
"img_width": 336,
"img_height": 336,
"patch_size": 14,
"projection_downsample_factor": None,
},
)
# Setting system prompt to empty string
avlm_sample_config.conversation_template_config.system = ''
task_encoder = AVLMTaskEncoder(
multimodal_sample_config=avlm_sample_config,
packed_sequence=args.use_packed_sequence,
packed_sequence_size=decoder_seq_length,
)
data = AVLMDataModule(
path=data_path,
num_workers=num_workers,
micro_batch_size=mbs,
global_batch_size=gbs,
seq_length=decoder_seq_length,
tokenizer=AutoTokenizer("llava-hf/llava-1.5-7b-hf"),
multimodal_sample_config=avlm_sample_config,
task_encoder=task_encoder,
packing_buffer_size=200 if args.use_packed_sequence else None,
)
elif args.data_type == "mock":
data = avlm.data.AVLMMockDataModule(
seq_length=decoder_seq_length,
global_batch_size=gbs,
micro_batch_size=mbs,
tokenizer=AutoTokenizer("llava-hf/llava-1.5-7b-hf"),
image_processor=None,
audio_processor=None,
num_workers=num_workers,
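            # Embedding token counts follow from the encoder geometry:
            #   image: (336 / 14) ** 2 = 576 patches for CLIP-ViT-L/14-336
            #   audio: 30 s * 100 mel frames/s / 2 (conv stride) = 1500 for Whisper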
image_embedding_tokens=576, # e.g. for CLIP-ViT-L-14-336
audio_embedding_tokens=1500, # e.g. for Whisper
)
else:
raise ValueError(f"Data type {args.data_type} not supported")
# Submodules configurations
language_transformer_config = llm.Llama2Config7B(
seq_length=decoder_seq_length,
attention_backend=AttnBackend.fused,
)
vision_transformer_config = vlm.HFCLIPVisionConfig(
pretrained_model_name_or_path="openai/clip-vit-large-patch14-336"
)
    vision_model_from_pretrained = None
    # Alternative: use NeMo's native CLIP config with locally converted weights, e.g.:
    # vision_transformer_config = vlm.CLIPViTL_14_336_Config()
    # vision_model_from_pretrained = "/root/.cache/nemo/models/openai/clip-vit-large-patch14"
vision_projection_config = vlm.MultimodalProjectorConfig(
projector_type=args.projector_type,
input_size=vision_transformer_config.hidden_size,
hidden_size=language_transformer_config.hidden_size,
ffn_hidden_size=language_transformer_config.hidden_size,
)
    # Whisper audio encoder, loaded as a Hugging Face AutoModel; only the
    # encoder submodule ("model.encoder") is used.
audio_transformer_config = ASRModuleConfig(
_target_="nemo.collections.speechlm.modules.asr_module.ASRModuleConfig",
use_hf_auto_model=True,
hf_trust_remote_code=False,
hf_load_pretrained_weights=True,
pretrained_model="openai/whisper-large-v3",
hidden_size=1280,
target_module="model.encoder",
)
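
    # The audio projector mirrors the vision one, mapping Whisper encoder
    # states (1280-dim) to the language model's hidden size.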
audio_projection_config = vlm.MultimodalProjectorConfig(
projector_type=args.projector_type,
        input_size=audio_transformer_config.hidden_size,  # 1280 for Whisper large-v3
hidden_size=language_transformer_config.hidden_size,
ffn_hidden_size=language_transformer_config.hidden_size,
)
# AVLM model configuration
avlm_config = avlm.AVLMConfig(
language_transformer_config=language_transformer_config,
vision_transformer_config=vision_transformer_config,
vision_projection_config=vision_projection_config,
audio_transformer_config=audio_transformer_config,
audio_projection_config=audio_projection_config,
language_model_from_pretrained=args.language_model_path,
vision_model_from_pretrained=vision_model_from_pretrained,
audio_model_from_pretrained=None,
freeze_language_model=True,
freeze_vision_model=True,
freeze_vision_projection=False,
freeze_audio_model=True,
freeze_audio_projection=False,
)
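
    # LLaVA-style stage-1 pretraining: the language model and both encoders
    # stay frozen; only the vision and audio projectors receive gradients.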
model = avlm.AVLMModel(avlm_config, tokenizer=data.tokenizer)
# Training strategy setup
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=args.tp_size,
pipeline_model_parallel_size=args.pp_size,
encoder_pipeline_model_parallel_size=args.encoder_pp_size,
context_parallel_size=args.cp_size,
pipeline_dtype=torch.bfloat16,
sequence_parallel=args.sequence_parallel,
ckpt_async_save=True,
)
# Checkpoint callback setup
checkpoint_callback = nl.ModelCheckpoint(
save_last=True,
monitor="reduced_train_loss",
save_top_k=5,
every_n_train_steps=5000,
dirpath=args.log_dir,
)
# Trainer setup
trainer = nl.Trainer(
num_nodes=args.num_nodes,
devices=args.devices,
max_steps=max_steps,
accelerator="gpu",
strategy=strategy,
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
callbacks=[checkpoint_callback, TimingCallback()],
val_check_interval=args.val_check_interval,
check_val_every_n_epoch=None,
        limit_val_batches=20,
log_every_n_steps=1,
num_sanity_val_steps=0,
)
# Logger setup
nemo_logger = nl.NeMoLogger(
log_dir=args.log_dir,
name=args.name,
wandb=WandbLogger(project=args.wandb_project, name=args.name) if args.wandb_project is not None else None,
)
# Auto resume setup
resume = nl.AutoResume(
resume_if_exists=True,
resume_ignore_no_checkpoint=True,
resume_from_directory=args.log_dir,
restore_config=(
nl.RestoreConfig(path=args.restore_path, load_optim_state=False) if args.restore_path is not None else None
),
)
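
    # resume_if_exists picks up the latest checkpoint under log_dir;
    # restore_path (when given) loads model weights only, since
    # load_optim_state=False.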
# Optimizer and scheduler setup
opt_config = OptimizerConfig(
optimizer='adam',
lr=args.lr,
adam_beta1=0.9,
adam_beta2=0.95,
use_distributed_optimizer=True,
bf16=True,
clip_grad=1.0,
)
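
    # Linear warmup for 150 steps, then cosine decay from args.lr down to
    # min_lr over the remaining training steps.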
sched = CosineAnnealingScheduler(
max_steps=trainer.max_steps,
warmup_steps=150,
constant_steps=0,
min_lr=2.0e-05,
)
opt = MegatronOptimizerModule(opt_config, sched)
llm.pretrain(
model=model,
data=data,
trainer=trainer,
log=nemo_logger,
optim=opt,
resume=resume,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="AVLM Pretraining Script")
# Argument parsing
parser.add_argument("--data_type", type=str, required=False, default="mock", help="mock | energon")
parser.add_argument("--data_path", type=str, required=False, default=None, help="Path to the dataset JSON file")
parser.add_argument(
"--log_dir", type=str, required=False, default="/results", help="Directory for logging and checkpoints"
)
parser.add_argument(
"--language_model_path", type=str, required=False, default=None, help="Path to the pretrained language model"
)
parser.add_argument(
"--restore_path", type=str, required=False, default=None, help="Path to restore model from checkpoint"
)
parser.add_argument("--devices", type=int, required=False, default=1)
parser.add_argument("--num_nodes", type=int, required=False, default=1)
parser.add_argument("--max_steps", type=int, required=False, default=2000)
parser.add_argument("--val_check_interval", type=int, required=False, default=500)
parser.add_argument("--seq_length", type=int, required=False, default=8192)
parser.add_argument("--tp_size", type=int, required=False, default=1)
parser.add_argument("--pp_size", type=int, required=False, default=1)
parser.add_argument("--encoder_pp_size", type=int, required=False, default=0)
parser.add_argument("--cp_size", type=int, required=False, default=1)
    parser.add_argument(
        "--sequence_parallel", type=str, required=False, default="false", help="Enable sequence parallelism ('true' or 'false')"
    )
    parser.add_argument(
        "--use_packed_sequence", type=str, required=False, default="false", help="Enable sequence packing ('true' or 'false')"
    )
parser.add_argument("--projector_type", type=str, required=False, default="mlp2x_gelu")
parser.add_argument("--name", type=str, required=False, default="avlm_pretrain")
parser.add_argument("--wandb_project", type=str, required=False, default=None)
parser.add_argument("--gbs", type=int, required=False, default=32, help="Global batch size")
parser.add_argument("--mbs", type=int, required=False, default=4, help="Micro batch size")
parser.add_argument(
"--num_workers", type=int, required=False, default=32, help="Number of workers for data loading"
)
parser.add_argument("--lr", type=float, required=False, default=0.001, help="Learning rate")
args = parser.parse_args()
main(args)