"""Training script for TiTok. Copyright (2024) Bytedance Ltd. and/or its affiliates Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Reference: https://github.com/huggingface/open-muse """ import math import os from pathlib import Path from accelerate.utils import set_seed from accelerate import Accelerator import torch from omegaconf import OmegaConf from utils.logger import setup_logger from utils.train_utils import ( get_config, create_pretrained_tokenizer, create_model_and_loss_module, create_optimizer, create_lr_scheduler, create_dataloader, create_evaluator, auto_resume, save_checkpoint, train_one_epoch) def main(): workspace = os.environ.get('WORKSPACE', '') if workspace: torch.hub.set_dir(workspace + "/models/hub") config = get_config() # Enable TF32 on Ampere GPUs. if config.training.enable_tf32: torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.benchmark = True torch.backends.cudnn.deterministic = False output_dir = config.experiment.output_dir os.makedirs(output_dir, exist_ok=True) config.experiment.logging_dir = os.path.join(output_dir, "logs") # Whether logging to Wandb or Tensorboard. tracker = "tensorboard" if config.training.enable_wandb: tracker = "wandb" accelerator = Accelerator( gradient_accumulation_steps=config.training.gradient_accumulation_steps, mixed_precision=config.training.mixed_precision, log_with=tracker, project_dir=config.experiment.logging_dir, split_batches=False, ) logger = setup_logger(name="TiTok", log_level="INFO", output_file=f"{output_dir}/log{accelerator.process_index}.txt") # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: accelerator.init_trackers(config.experiment.name) config_path = Path(output_dir) / "config.yaml" logger.info(f"Saving config to {config_path}") OmegaConf.save(config, config_path) logger.info(f"Config:\n{OmegaConf.to_yaml(config)}") # If passed along, set the training seed now. if config.training.seed is not None: set_seed(config.training.seed, device_specific=True) if accelerator.local_process_index == 0: # download the maskgit-vq tokenizer weight from huggingface_hub import hf_hub_download hf_hub_download(repo_id="fun-research/TiTok", filename=f"{config.model.vq_model.pretrained_tokenizer_weight}", local_dir="./") accelerator.wait_for_everyone() pretrained_tokenizer = create_pretrained_tokenizer(config, accelerator) model, ema_model, loss_module = create_model_and_loss_module( config, logger, accelerator, model_type="titok") optimizer, discriminator_optimizer = create_optimizer(config, logger, model, loss_module) lr_scheduler, discriminator_lr_scheduler = create_lr_scheduler( config, logger, accelerator, optimizer, discriminator_optimizer) train_dataloader, eval_dataloader = create_dataloader(config, logger, accelerator) # Set up evaluator. evaluator = create_evaluator(config, logger, accelerator) # Prepare everything with accelerator. logger.info("Preparing model, optimizer and dataloaders") # The dataloader are already aware of distributed training, so we don't need to prepare them. if config.model.vq_model.finetune_decoder: model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler = accelerator.prepare( model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler ) else: model, optimizer, lr_scheduler = accelerator.prepare( model, optimizer, lr_scheduler ) if config.training.use_ema: ema_model.to(accelerator.device) total_batch_size_without_accum = config.training.per_gpu_batch_size * accelerator.num_processes num_batches = math.ceil( config.experiment.max_train_examples / total_batch_size_without_accum) num_update_steps_per_epoch = math.ceil(num_batches / config.training.gradient_accumulation_steps) num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) # Start training. logger.info("***** Running training *****") logger.info(f" Num training steps = {config.training.max_train_steps}") logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") logger.info(f" Instantaneous batch size per gpu = { config.training.per_gpu_batch_size}") logger.info(f""" Total train batch size (w. parallel, distributed & accumulation) = {( config.training.per_gpu_batch_size * accelerator.num_processes * config.training.gradient_accumulation_steps)}""") global_step = 0 first_epoch = 0 global_step, first_epoch = auto_resume( config, logger, accelerator, ema_model, num_update_steps_per_epoch, strict=True) for current_epoch in range(first_epoch, num_train_epochs): accelerator.print(f"Epoch {current_epoch}/{num_train_epochs-1} started.") global_step = train_one_epoch(config, logger, accelerator, model, ema_model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler, train_dataloader, eval_dataloader, evaluator, global_step, pretrained_tokenizer=pretrained_tokenizer) # Stop training if max steps is reached. if global_step >= config.training.max_train_steps: accelerator.print( f"Finishing training: Global step is >= Max train steps: {global_step} >= {config.training.max_train_steps}" ) break accelerator.wait_for_everyone() # Save checkpoint at the end of training. save_checkpoint(model, output_dir, accelerator, global_step, logger=logger) # Save the final trained checkpoint if accelerator.is_main_process: model = accelerator.unwrap_model(model) if config.training.use_ema: ema_model.copy_to(model.parameters()) model.save_pretrained_weight(output_dir) accelerator.end_training() if __name__ == "__main__": main()