# Provenance residue from extraction (likely author / unknown field / commit hash):
# xuehongyang
# ser
# 83d8d3c
import os
import sys
import torch
from loguru import logger
from configs.train_config import TrainConfig
from data.dataset import TrainDatasetDataLoader
from models.model import HifiFace
from utils.visualizer import Visualizer
# DDP on/off is decided by the training config, read once at import time.
# torch.distributed is only imported when a distributed run is requested,
# which is the only case in which setup()/cleanup() below are called.
if (use_ddp := TrainConfig().use_ddp):
    import torch.distributed as dist
def setup():
    """Initialise the NCCL default process group and return this process's global rank.

    Meant to run under ``torchrun``, which exports MASTER_ADDR, MASTER_PORT,
    RANK, WORLD_SIZE and LOCAL_RANK. That is why ``init_process_group`` is
    called without explicit rank/world_size: it uses the default ``env://``
    rendezvous.

    Returns:
        int: the global rank from ``dist.get_rank()``.
    """
    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is not None:
        # Bind this worker to its own GPU before NCCL initialises. Otherwise
        # collectives that infer the device (e.g. barrier) default to cuda:0
        # in every process.
        torch.cuda.set_device(int(local_rank))
    dist.init_process_group("nccl")
    return dist.get_rank()
def cleanup():
    """Destroy the default process group if one is currently active.

    Safe to call more than once, or when ``setup()`` never completed. In
    those cases it is a no-op instead of raising from
    ``destroy_process_group``.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
def train():
    """Train HifiFace until ``TrainConfig().max_iters`` iterations have run.

    If DDP is enabled (module-level ``use_ddp``), the process group is created
    by ``setup()``. It is always torn down by ``cleanup()`` on the way out,
    including when training raises. Returns normally once the iteration
    budget is reached.
    """
    rank = 0
    if use_ddp:
        rank = setup()
    try:
        _train_loop(rank)
    finally:
        if use_ddp:
            cleanup()


def _train_loop(rank):
    """Build data, model and visualizer, then run epochs until max_iters is reached.

    Args:
        rank: global process rank (0 when DDP is disabled). Only the main
            process (non-DDP, or rank 0) visualizes and writes checkpoints.

    Raises:
        RuntimeError: if a full pass over the dataloader yields no batches.
            Without this check the epoch loop would spin forever.
    """
    # torchrun exports LOCAL_RANK, the GPU index on *this* node. The global
    # rank only matches it on a single node, so prefer LOCAL_RANK when set.
    local_rank = int(os.environ.get("LOCAL_RANK", rank)) if use_ddp else rank
    device = torch.device(f"cuda:{local_rank}")
    logger.info(f"use device {device}")

    opt = TrainConfig()
    dataloader = TrainDatasetDataLoader()
    dataset_length = len(dataloader)
    logger.info(f"Dataset length: {dataset_length}")

    model = HifiFace(
        opt.identity_extractor_config, is_training=True, device=device, load_checkpoint=opt.load_checkpoint
    )
    model.train()
    logger.info("model initialized")

    is_main_process = not use_ddp or rank == 0
    visualizer = Visualizer(opt) if is_main_process else None

    total_iter = 0
    epoch = 0
    while True:
        if use_ddp:
            # presumably a DistributedSampler: reseed its shuffle each epoch
            dataloader.train_sampler.set_epoch(epoch)
        epoch_start_iter = total_iter
        for data in dataloader:
            source_image = data["source_image"].to(device)
            target_image = data["target_image"].to(device)
            target_mask = data["target_mask"].to(device)
            same = data["same"].to(device)
            loss_dict, visual_dict = model.optimize(source_image, target_image, target_mask, same)
            total_iter += 1

            if visualizer is not None:
                if total_iter % opt.visualize_interval == 0:
                    visualizer.display_current_results(total_iter, visual_dict)
                if total_iter % opt.plot_interval == 0:
                    visualizer.plot_current_losses(total_iter, loss_dict)
                    _log_losses(total_iter, loss_dict)

            saved_this_iter = is_main_process and total_iter % opt.checkpoint_interval == 0
            if saved_this_iter:
                logger.info(f"Saving model at iter {total_iter}")
                model.save(opt.checkpoint_dir, total_iter)

            # Stop after exactly max_iters optimisation steps.
            if total_iter >= opt.max_iters:
                logger.info("Maximum iterations reached. Stopping training.")
                # Skip the final save if the periodic save just wrote this iteration.
                if is_main_process and not saved_this_iter:
                    model.save(opt.checkpoint_dir, total_iter)
                return
        if total_iter == epoch_start_iter:
            raise RuntimeError("Dataloader yielded no batches; training cannot make progress.")
        epoch += 1


def _log_losses(total_iter, loss_dict):
    """Log each loss term for ``total_iter``, followed by a separator line."""
    logger.info(f"Iter: {total_iter}")
    for name, value in loss_dict.items():
        logger.info(f" {name}: {value}")
    logger.info("=" * 20)
if __name__ == "__main__":
    # DDP runs are launched with torchrun, e.g.:
    # CUDA_VISIBLE_DEVICES=2,3 torchrun --nnodes=1 --nproc_per_node=2 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:29400 -m entry.train
    if use_ddp:
        # Intent: one OpenMP thread per worker process.
        # NOTE(review): torch is already imported here, so this probably only
        # affects child processes started afterwards. torchrun also defaults
        # OMP_NUM_THREADS to 1. Confirm whether torch.set_num_threads(1) is wanted.
        os.environ["OMP_NUM_THREADS"] = "1"
    train()