import os
import sys

import torch
from loguru import logger

from configs.train_config import TrainConfig
from data.dataset import TrainDatasetDataLoader
from models.model import HifiFace
from utils.visualizer import Visualizer

use_ddp = TrainConfig().use_ddp

if use_ddp:
    import torch.distributed as dist

    def setup():
        # torchrun supplies MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE via
        # the environment, so init_process_group only needs the backend.
        dist.init_process_group("nccl")
        return dist.get_rank()

    def cleanup():
        dist.destroy_process_group()


def train():
    # Under DDP each process drives one GPU, indexed by its global rank.
    rank = 0
    if use_ddp:
        rank = setup()
    device = torch.device(f"cuda:{rank}")
    logger.info(f"using device {device}")

    opt = TrainConfig()
    dataloader = TrainDatasetDataLoader()
    logger.info(f"Dataset length: {len(dataloader)}")

    model = HifiFace(
        opt.identity_extractor_config,
        is_training=True,
        device=device,
        load_checkpoint=opt.load_checkpoint,
    )
    model.train()
    logger.info("model initialized")

    # Only rank 0 visualizes and saves checkpoints, so DDP workers
    # do not write over each other.
    visualizer = None
    ckpt = False
    if not use_ddp or rank == 0:
        visualizer = Visualizer(opt)
        ckpt = True

    total_iter = 0
    epoch = 0
    while True:
        if use_ddp:
            # Reshuffle the distributed sampler each epoch.
            dataloader.train_sampler.set_epoch(epoch)
        for data in dataloader:
            source_image = data["source_image"].to(device)
            target_image = data["target_image"].to(device)
            target_mask = data["target_mask"].to(device)
            same = data["same"].to(device)
            loss_dict, visual_dict = model.optimize(source_image, target_image, target_mask, same)

            total_iter += 1
            if total_iter % opt.visualize_interval == 0 and visualizer is not None:
                visualizer.display_current_results(total_iter, visual_dict)

            if total_iter % opt.plot_interval == 0 and visualizer is not None:
                visualizer.plot_current_losses(total_iter, loss_dict)
                logger.info(f"Iter: {total_iter}")
                for k, v in loss_dict.items():
                    logger.info(f"  {k}: {v}")
                logger.info("=" * 20)

            if total_iter % opt.checkpoint_interval == 0 and ckpt:
                logger.info(f"Saving model at iter {total_iter}")
                model.save(opt.checkpoint_dir, total_iter)

            if total_iter > opt.max_iters:
                logger.info("Reached maximum iterations, stopping training.")
                if ckpt:
                    model.save(opt.checkpoint_dir, total_iter)
                if use_ddp:
                    cleanup()
                sys.exit(0)
        epoch += 1


if __name__ == "__main__":
    if use_ddp:
        # Launch with torchrun, e.g.:
        # CUDA_VISIBLE_DEVICES=2,3 torchrun --nnodes=1 --nproc_per_node=2 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:29400 -m entry.train
        os.environ["OMP_NUM_THREADS"] = "1"
    train()