#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import os
import sys

import datasets
import numpy as np
import torch
import transformers
from aac_metrics import evaluate
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoTokenizer,
    BartConfig,
    get_inverse_sqrt_schedule,
    get_scheduler,
)

from data.collator import DataCollatorForEnClapBart
from data.preprocess import Preprocessor
from modeling.enclap_bart import EnClapBartForConditionalGeneration

logger = get_logger(__name__)
metric_list = ["meteor", "spider"]


def main():
    # Load Configuration
    cfg_path = sys.argv[1]
    args = OmegaConf.load(cfg_path)

    # Initialize Logging
    accelerator_log_kwargs = {}
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    if args.with_tracking:
        accelerator_log_kwargs["log_with"] = args.report_to
        accelerator_log_kwargs["project_dir"] = args.output_dir

    # Initialize Accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        split_batches=args.split_batches,
        kwargs_handlers=[ddp_kwargs],
        **accelerator_log_kwargs,
    )

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
            with open(os.path.join(args.output_dir, "args.yaml"), "w") as f:
                OmegaConf.save(args, f)
    accelerator.wait_for_everyone()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    file_handler = logging.FileHandler(os.path.join(args.output_dir, "train_log.txt"))
    logger.logger.addHandler(file_handler)
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

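    # NOTE: every `args.*` attribute used in this script must be supplied by the
    # OmegaConf YAML passed as the first CLI argument. A minimal sketch of that file
    # (key names are taken from the attributes referenced in this script; the values
    # and paths below are illustrative assumptions, not repository defaults):
    #
    #   output_dir: runs/enclap_base          # hypothetical path
    #   train_file: data/train.csv            # hypothetical path
    #   validation_file: data/valid.csv       # hypothetical path
    #   tokenizer_name: facebook/bart-base    # illustrative choice
    #   model_name_or_path: facebook/bart-base
    #   learning_rate: 1.0e-4                 # illustrative value
    #   per_device_train_batch_size: 8        # illustrative value
    #   ...                                   # remaining args.* keys
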
    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Get the datasets
    data_files = {}
    data_files_eval = {}
    if args.train_file is not None:
        data_files["train"] = args.train_file
    if args.validation_file is not None:
        data_files_eval["validation"] = args.validation_file
    extension = args.train_file.split(".")[-1]
    raw_datasets = load_dataset(extension, data_files=data_files)
    raw_datasets_eval = load_dataset(extension, data_files=data_files_eval)

    # Load pretrained model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
    if args.config_name_or_path is not None:
        config = BartConfig.from_pretrained(args.config_name_or_path)
    else:
        config = None

    if args.model_name_or_path is not None:
        if config is None:
            model = EnClapBartForConditionalGeneration.from_pretrained(
                args.model_name_or_path
            )
        else:
            model = EnClapBartForConditionalGeneration.from_pretrained(
                args.model_name_or_path, config=config
            )
    else:
        model = EnClapBartForConditionalGeneration(config=config)

    # Set the generation config
    if args.val_max_target_length is None:
        args.val_max_target_length = args.max_target_length

    # Set max encodec length based on the shape of the positional encoding
    max_encodec_length = model.config.max_position_embeddings - 2
    label_pad_token_id = (
        -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    )
    preprocessor = Preprocessor(
        args.encodec_base_path,
        args.clap_base_path,
        tokenizer,
        model.config.max_position_embeddings,
        args.encodec_masking_prob,
        args.encodec_masking_span,
        label_pad_token_id,
        model.config.encodec_vocab_size,
        args.eval_num_captions,
    )

    with accelerator.main_process_first():
        train_dataset = raw_datasets["train"].map(
            preprocessor.preprocess_train,
            num_proc=args.preprocessing_num_workers,
            load_from_cache_file=not args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
        train_dataset.set_format(
            "pt",
            columns=[
                "input_ids",
                "attention_mask",
                "clap",
                "labels",
                "decoder_attention_mask",
            ],
        )
        # Temporarily set max_target_length for validation.
        eval_dataset = raw_datasets_eval["validation"].map(
            preprocessor.preprocess_eval,
            num_proc=args.preprocessing_num_workers,
            load_from_cache_file=not args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
        eval_dataset.set_format(
            "pt",
            columns=["input_ids", "attention_mask", "clap"],
            output_all_columns=True,
        )

    train_data_collator = DataCollatorForEnClapBart(
        tokenizer=tokenizer,
        model=model,
        return_tensors="pt",
        label_pad_token_id=label_pad_token_id,
        max_length=max_encodec_length,
        encodec_masking_prob=args.encodec_masking_prob,
        encodec_masking_span=args.encodec_masking_span,
    )
    valid_data_collator = DataCollatorForEnClapBart(
        tokenizer=tokenizer,
        model=model,
        return_tensors="pt",
        label_pad_token_id=label_pad_token_id,
        max_length=max_encodec_length,
    )
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=train_data_collator,
        batch_size=args.per_device_train_batch_size,
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        collate_fn=valid_data_collator,
        batch_size=args.per_device_eval_batch_size,
    )

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"] optimizer_grouped_parameters = [ { "params": [ p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) ], "weight_decay": args.weight_decay, }, { "params": [ p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) ], "weight_decay": 0.0, }, ] optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil( len(train_dataloader) / args.gradient_accumulation_steps ) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True if args.lr_scheduler_type == "inverse_sqrt" and hasattr(args, "time_scale"): lr_scheduler = get_inverse_sqrt_schedule( optimizer=optimizer, num_warmup_steps=args.num_warmup_steps, timescale=args.time_scale, ) else: lr_scheduler = get_scheduler( name=args.lr_scheduler_type, optimizer=optimizer, num_warmup_steps=args.num_warmup_steps, num_training_steps=args.max_train_steps, ) # Prepare everything with our `accelerator`. ( model, optimizer, train_dataloader, eval_dataloader, lr_scheduler, ) = accelerator.prepare( model, optimizer, train_dataloader, eval_dataloader, lr_scheduler ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil( len(train_dataloader) / args.gradient_accumulation_steps ) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) # Figure out how many steps we should save the Accelerator states checkpointing_steps = args.checkpointing_steps if checkpointing_steps is not None and checkpointing_steps.isdigit(): checkpointing_steps = int(checkpointing_steps) # The trackers initializes automatically on the main process. if args.with_tracking: accelerator.init_trackers(args.logging_dir) # Train! total_batch_size = ( args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps ) if args.split_batches: total_batch_size = int(total_batch_size / accelerator.num_processes) logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info( f" Instantaneous batch size per device = {args.per_device_train_batch_size}" ) logger.info( f" Total train batch size (w. 
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    completed_steps = 0
    starting_epoch = 0
    resume_step = None

    # Potentially load in the weights and states from a previous save
    if not args.overwrite_output_dir and os.path.exists(
        os.path.join(args.output_dir, "checkpoints")
    ):
        if args.resume_from_checkpoint is not None:
            accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
            accelerator.load_state(args.resume_from_checkpoint)
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = [
                f
                for f in os.scandir(os.path.join(args.output_dir, "checkpoints"))
                if f.is_dir()
            ]
            # Sort folders by date modified; the most recent checkpoint is the last
            dirs.sort(key=os.path.getctime)
            path = dirs[-1].name
            accelerator.print(f"Resumed from checkpoint: {dirs[-1]}")
            accelerator.load_state(dirs[-1])
        # Extract `epoch_{i}` or `step_{i}`
        training_difference = os.path.splitext(path)[0]

        if "epoch" in training_difference:
            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
            resume_step = None
            completed_steps = starting_epoch * num_update_steps_per_epoch
        else:
            # need to multiply `gradient_accumulation_steps` to reflect real steps
            resume_step = (
                int(training_difference.replace("step_", ""))
                * args.gradient_accumulation_steps
            )
            starting_epoch = resume_step // len(train_dataloader)
            resume_step -= starting_epoch * len(train_dataloader)
            completed_steps = resume_step // args.gradient_accumulation_steps

    # Running-loss counters used for the tqdm postfix and periodic logging. They are
    # initialized unconditionally so the progress bar also works when tracking is off.
    total_loss = 0
    logging_loss = 0
    before_epoch_loss = 0
    if args.encodec_masking_prob > 0:
        total_encodec_loss = 0
        logging_encodec_loss = 0
        before_epoch_encodec_loss = 0

    for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()
        if epoch == starting_epoch and resume_step is not None:
            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
            active_dataloader = accelerator.skip_first_batches(
                train_dataloader, resume_step
            )
        else:
            active_dataloader = train_dataloader

        logger.info(f"***** Running epoch {epoch} *****")
        epoch_iterator = tqdm(
            active_dataloader,
            desc="Training",
            disable=not accelerator.is_local_main_process,
            dynamic_ncols=True,
            colour="CYAN",
        )
        for step, batch in enumerate(epoch_iterator):
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                # We keep track of the loss at each epoch
                total_loss += outputs.lm_loss.item()
                if args.encodec_masking_prob > 0:
                    if outputs.encodec_loss is not None:
                        total_encodec_loss += outputs.encodec_loss.item()
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        model.parameters(), max_norm=args.max_grad_norm
                    )
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                completed_steps += 1
                # Add loss information to tqdm
                epoch_iterator.set_postfix(loss=total_loss / completed_steps)

                if completed_steps % args.logging_steps == 0:
                    train_log = {
                        "train/learning_rate": lr_scheduler.get_last_lr()[0]
                    }
                    train_log["train/loss"] = (
                        total_loss - logging_loss
                    ) / args.logging_steps
                    logging_loss = total_loss
                    if args.encodec_masking_prob > 0:
                        train_log["train/encodec_loss"] = (
                            total_encodec_loss - logging_encodec_loss
                        ) / args.logging_steps
                        logging_encodec_loss = total_encodec_loss
                    accelerator.log(train_log, step=completed_steps)

                if isinstance(checkpointing_steps, int):
                    if completed_steps % checkpointing_steps == 0:
                        output_dir = f"step_{completed_steps}"
                        if args.output_dir is not None:
                            output_dir = os.path.join(
                                args.output_dir, "checkpoints", output_dir
                            )
                        accelerator.save_state(output_dir)

            if completed_steps >= args.max_train_steps:
                break

        model.eval()
        gen_kwargs = {
            "max_length": args.val_max_target_length,
        }
        predictions = []
        references = []
        eval_iterator = tqdm(
            eval_dataloader,
            desc="Validation",
            disable=not accelerator.is_local_main_process,
            dynamic_ncols=True,
            colour="MAGENTA",
        )
        for step, batch in enumerate(eval_iterator):
            # Drop the padded samples of the last batch of dataloader
            # try:
            #     if accelerator.gradient_state.end_of_dataloader and accelerator.gradient_state.remainder > 0:
            #         batch = batch[:accelerator.gradient_state.remainder]
            # except:
            #     pass
            with torch.no_grad():
                batch["input_ids"] = batch["input_ids"].cuda()
                batch["clap"] = batch["clap"].cuda()
                batch["attention_mask"] = batch["attention_mask"].cuda()
                batch["eos_mask"] = batch["eos_mask"].cuda()
                generated_tokens = accelerator.unwrap_model(model).generate(
                    batch["input_ids"],
                    clap=batch["clap"],
                    attention_mask=batch["attention_mask"],
                    eos_mask=batch["eos_mask"],
                    **gen_kwargs,
                )
                generated_tokens = accelerator.pad_across_processes(
                    generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
                )
                generated_tokens = generated_tokens.cpu().numpy()
                captions = batch["captions"]

                if isinstance(generated_tokens, tuple):
                    generated_tokens = generated_tokens[0]
                decoded_preds = tokenizer.batch_decode(
                    generated_tokens, skip_special_tokens=True
                )

                predictions.extend(decoded_preds)
                references.extend(captions)

        logger.info("Evaluating predictions...")
        result = evaluate(predictions, references, metrics=metric_list)

        # Gather Result
        result = {k: v.cuda() for k, v in result[0].items()}
        result = accelerator.gather_for_metrics(result)
        # Log the average of metrics among the processes
        if accelerator.num_processes > 1:
            result = {f"eval/{k}": round(v.mean().item(), 4) for k, v in result.items()}
        else:
            result = {f"eval/{k}": round(v.item(), 4) for k, v in result.items()}
        logger.info(result)

        if args.with_tracking:
            result["train/epoch_train_loss"] = (total_loss - before_epoch_loss) / len(
                train_dataloader
            )
            result["train/steps"] = completed_steps
            before_epoch_loss = total_loss
            if args.encodec_masking_prob > 0:
                result["train/epoch_encodec_loss"] = (
                    total_encodec_loss - before_epoch_encodec_loss
                ) / len(train_dataloader)
                before_epoch_encodec_loss = total_encodec_loss
        accelerator.log(result, step=epoch)

        if args.checkpointing_steps == "epoch":
            output_dir = f"epoch_{epoch}"
            if args.output_dir is not None:
                output_dir = os.path.join(args.output_dir, "checkpoints", output_dir)
            accelerator.save_state(output_dir)
            if accelerator.is_main_process:
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.config.save_pretrained(output_dir)

    if args.output_dir is not None:
        save_dir = os.path.join(args.output_dir, "final")
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            save_dir,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
        )
        if accelerator.is_main_process:
            tokenizer.save_pretrained(save_dir)


if __name__ == "__main__":
    main()
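
# A minimal launch sketch (the script filename and config path below are assumptions,
# not something this file defines): the single positional argument is the OmegaConf
# YAML read at the top of main(). Since the script is built on Accelerate, multi-GPU
# runs typically go through `accelerate launch`, e.g.
#
#   accelerate launch train.py configs/train.yaml
#
# while a single-process run can simply use
#
#   python train.py configs/train.yaml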