| """ | |
| # First update `train_config.py` to set paths to your dataset locations. | |
| # You may want to change `--num-workers` according to your machine's memory. | |
| # The default num-workers=8 may cause dataloader to exit unexpectedly when | |
| # machine is out of memory. | |
| # Stage 1 | |
| python train.py \ | |
| --model-variant mobilenetv3 \ | |
| --dataset videomatte \ | |
| --resolution-lr 512 \ | |
| --seq-length-lr 15 \ | |
| --learning-rate-backbone 0.0001 \ | |
| --learning-rate-aspp 0.0002 \ | |
| --learning-rate-decoder 0.0002 \ | |
| --learning-rate-refiner 0 \ | |
| --checkpoint-dir checkpoint/stage1 \ | |
| --log-dir log/stage1 \ | |
| --epoch-start 0 \ | |
| --epoch-end 20 | |
| # Stage 2 | |
| python train.py \ | |
| --model-variant mobilenetv3 \ | |
| --dataset videomatte \ | |
| --resolution-lr 512 \ | |
| --seq-length-lr 50 \ | |
| --learning-rate-backbone 0.00005 \ | |
| --learning-rate-aspp 0.0001 \ | |
| --learning-rate-decoder 0.0001 \ | |
| --learning-rate-refiner 0 \ | |
| --checkpoint checkpoint/stage1/epoch-19.pth \ | |
| --checkpoint-dir checkpoint/stage2 \ | |
| --log-dir log/stage2 \ | |
| --epoch-start 20 \ | |
| --epoch-end 22 | |
| # Stage 3 | |
| python train.py \ | |
| --model-variant mobilenetv3 \ | |
| --dataset videomatte \ | |
| --train-hr \ | |
| --resolution-lr 512 \ | |
| --resolution-hr 2048 \ | |
| --seq-length-lr 40 \ | |
| --seq-length-hr 6 \ | |
| --learning-rate-backbone 0.00001 \ | |
| --learning-rate-aspp 0.00001 \ | |
| --learning-rate-decoder 0.00001 \ | |
| --learning-rate-refiner 0.0002 \ | |
| --checkpoint checkpoint/stage2/epoch-21.pth \ | |
| --checkpoint-dir checkpoint/stage3 \ | |
| --log-dir log/stage3 \ | |
| --epoch-start 22 \ | |
| --epoch-end 23 | |
| # Stage 4 | |
| python train.py \ | |
| --model-variant mobilenetv3 \ | |
| --dataset imagematte \ | |
| --train-hr \ | |
| --resolution-lr 512 \ | |
| --resolution-hr 2048 \ | |
| --seq-length-lr 40 \ | |
| --seq-length-hr 6 \ | |
| --learning-rate-backbone 0.00001 \ | |
| --learning-rate-aspp 0.00001 \ | |
| --learning-rate-decoder 0.00005 \ | |
| --learning-rate-refiner 0.0002 \ | |
| --checkpoint checkpoint/stage3/epoch-22.pth \ | |
| --checkpoint-dir checkpoint/stage4 \ | |
| --log-dir log/stage4 \ | |
| --epoch-start 23 \ | |
| --epoch-end 28 | |
| """ | |

import argparse
import torch
import random
import os

from torch import nn
from torch import distributed as dist
from torch import multiprocessing as mp
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, ConcatDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from torchvision.transforms.functional import center_crop
from tqdm import tqdm

from dataset.videomatte import (
    VideoMatteDataset,
    VideoMatteTrainAugmentation,
    VideoMatteValidAugmentation,
)
from dataset.imagematte import (
    ImageMatteDataset,
    ImageMatteAugmentation,
)
from dataset.coco import (
    CocoPanopticDataset,
    CocoPanopticTrainAugmentation,
)
from dataset.spd import (
    SuperviselyPersonDataset,
)
from dataset.youtubevis import (
    YouTubeVISDataset,
    YouTubeVISAugmentation,
)
from dataset.augmentation import (
    TrainFrameSampler,
    ValidFrameSampler,
)
from model import MattingNetwork
from train_config import DATA_PATHS
from train_loss import matting_loss, segmentation_loss
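
# `train_config.py` must define DATA_PATHS with the keys referenced below. A
# minimal sketch, inferred from the lookups in this file (all path values are
# placeholders, not the authors' actual layout):
#
# DATA_PATHS = {
#     'videomatte':        {'train': '...', 'valid': '...'},
#     'imagematte':        {'train': '...', 'valid': '...'},
#     'background_images': {'train': '...', 'valid': '...'},
#     'background_videos': {'train': '...', 'valid': '...'},
#     'coco_panoptic':     {'imgdir': '...', 'anndir': '...', 'annfile': '...'},
#     'spd':               {'imgdir': '...', 'segdir': '...'},
#     'youtubevis':        {'videodir': '...', 'annfile': '...'},
# }
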
class Trainer:
    def __init__(self, rank, world_size):
        self.parse_args()
        self.init_distributed(rank, world_size)
        self.init_datasets()
        self.init_model()
        self.init_writer()
        self.train()
        self.cleanup()

    def parse_args(self):
        parser = argparse.ArgumentParser()
        # Model
        parser.add_argument('--model-variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
        # Matting dataset
        parser.add_argument('--dataset', type=str, required=True, choices=['videomatte', 'imagematte'])
        # Learning rates
        parser.add_argument('--learning-rate-backbone', type=float, required=True)
        parser.add_argument('--learning-rate-aspp', type=float, required=True)
        parser.add_argument('--learning-rate-decoder', type=float, required=True)
        parser.add_argument('--learning-rate-refiner', type=float, required=True)
        # Training settings
        parser.add_argument('--train-hr', action='store_true')
        parser.add_argument('--resolution-lr', type=int, default=512)
        parser.add_argument('--resolution-hr', type=int, default=2048)
        parser.add_argument('--seq-length-lr', type=int, required=True)
        parser.add_argument('--seq-length-hr', type=int, default=6)
        parser.add_argument('--downsample-ratio', type=float, default=0.25)
        parser.add_argument('--batch-size-per-gpu', type=int, default=1)
        parser.add_argument('--num-workers', type=int, default=8)
        parser.add_argument('--epoch-start', type=int, default=0)
        parser.add_argument('--epoch-end', type=int, default=16)
        # Tensorboard logging
        parser.add_argument('--log-dir', type=str, required=True)
        parser.add_argument('--log-train-loss-interval', type=int, default=20)
        parser.add_argument('--log-train-images-interval', type=int, default=500)
        # Checkpoint loading and saving
        parser.add_argument('--checkpoint', type=str)
        parser.add_argument('--checkpoint-dir', type=str, required=True)
        parser.add_argument('--checkpoint-save-interval', type=int, default=500)
        # Distributed
        parser.add_argument('--distributed-addr', type=str, default='localhost')
        parser.add_argument('--distributed-port', type=str, default='12355')
        # Debugging
        parser.add_argument('--disable-progress-bar', action='store_true')
        parser.add_argument('--disable-validation', action='store_true')
        parser.add_argument('--disable-mixed-precision', action='store_true')
        self.args = parser.parse_args()

    def init_distributed(self, rank, world_size):
        self.rank = rank
        self.world_size = world_size
        self.log('Initializing distributed')
        os.environ['MASTER_ADDR'] = self.args.distributed_addr
        os.environ['MASTER_PORT'] = self.args.distributed_port
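        # One process per GPU: init_process_group's default env:// rendezvous
        # reads the MASTER_ADDR and MASTER_PORT values set above.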
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    def init_datasets(self):
        self.log('Initializing matting datasets')
        size_hr = (self.args.resolution_hr, self.args.resolution_hr)
        size_lr = (self.args.resolution_lr, self.args.resolution_lr)

        # Matting datasets:
        if self.args.dataset == 'videomatte':
            self.dataset_lr_train = VideoMatteDataset(
                videomatte_dir=DATA_PATHS['videomatte']['train'],
                background_image_dir=DATA_PATHS['background_images']['train'],
                background_video_dir=DATA_PATHS['background_videos']['train'],
                size=self.args.resolution_lr,
                seq_length=self.args.seq_length_lr,
                seq_sampler=TrainFrameSampler(),
                transform=VideoMatteTrainAugmentation(size_lr))
            if self.args.train_hr:
                self.dataset_hr_train = VideoMatteDataset(
                    videomatte_dir=DATA_PATHS['videomatte']['train'],
                    background_image_dir=DATA_PATHS['background_images']['train'],
                    background_video_dir=DATA_PATHS['background_videos']['train'],
                    size=self.args.resolution_hr,
                    seq_length=self.args.seq_length_hr,
                    seq_sampler=TrainFrameSampler(),
                    transform=VideoMatteTrainAugmentation(size_hr))
            self.dataset_valid = VideoMatteDataset(
                videomatte_dir=DATA_PATHS['videomatte']['valid'],
                background_image_dir=DATA_PATHS['background_images']['valid'],
                background_video_dir=DATA_PATHS['background_videos']['valid'],
                size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
                seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
                seq_sampler=ValidFrameSampler(),
                transform=VideoMatteValidAugmentation(size_hr if self.args.train_hr else size_lr))
        else:
            self.dataset_lr_train = ImageMatteDataset(
                imagematte_dir=DATA_PATHS['imagematte']['train'],
                background_image_dir=DATA_PATHS['background_images']['train'],
                background_video_dir=DATA_PATHS['background_videos']['train'],
                size=self.args.resolution_lr,
                seq_length=self.args.seq_length_lr,
                seq_sampler=TrainFrameSampler(),
                transform=ImageMatteAugmentation(size_lr))
            if self.args.train_hr:
                self.dataset_hr_train = ImageMatteDataset(
                    imagematte_dir=DATA_PATHS['imagematte']['train'],
                    background_image_dir=DATA_PATHS['background_images']['train'],
                    background_video_dir=DATA_PATHS['background_videos']['train'],
                    size=self.args.resolution_hr,
                    seq_length=self.args.seq_length_hr,
                    seq_sampler=TrainFrameSampler(),
                    transform=ImageMatteAugmentation(size_hr))
            self.dataset_valid = ImageMatteDataset(
                imagematte_dir=DATA_PATHS['imagematte']['valid'],
                background_image_dir=DATA_PATHS['background_images']['valid'],
                background_video_dir=DATA_PATHS['background_videos']['valid'],
                size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
                seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
                seq_sampler=ValidFrameSampler(),
                transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr))

        # Matting dataloaders:
        self.datasampler_lr_train = DistributedSampler(
            dataset=self.dataset_lr_train,
            rank=self.rank,
            num_replicas=self.world_size,
            shuffle=True)
        self.dataloader_lr_train = DataLoader(
            dataset=self.dataset_lr_train,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            sampler=self.datasampler_lr_train,
            pin_memory=True)
        if self.args.train_hr:
            self.datasampler_hr_train = DistributedSampler(
                dataset=self.dataset_hr_train,
                rank=self.rank,
                num_replicas=self.world_size,
                shuffle=True)
            self.dataloader_hr_train = DataLoader(
                dataset=self.dataset_hr_train,
                batch_size=self.args.batch_size_per_gpu,
                num_workers=self.args.num_workers,
                sampler=self.datasampler_hr_train,
                pin_memory=True)
        self.dataloader_valid = DataLoader(
            dataset=self.dataset_valid,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            pin_memory=True)

        # Segmentation datasets:
        self.log('Initializing image segmentation datasets')
        self.dataset_seg_image = ConcatDataset([
            CocoPanopticDataset(
                imgdir=DATA_PATHS['coco_panoptic']['imgdir'],
                anndir=DATA_PATHS['coco_panoptic']['anndir'],
                annfile=DATA_PATHS['coco_panoptic']['annfile'],
                transform=CocoPanopticTrainAugmentation(size_lr)),
            SuperviselyPersonDataset(
                imgdir=DATA_PATHS['spd']['imgdir'],
                segdir=DATA_PATHS['spd']['segdir'],
                transform=CocoPanopticTrainAugmentation(size_lr)),
        ])
        self.datasampler_seg_image = DistributedSampler(
            dataset=self.dataset_seg_image,
            rank=self.rank,
            num_replicas=self.world_size,
            shuffle=True)
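        # Image segmentation samples are single frames, so the batch size is
        # multiplied by seq_length_lr; each step then sees roughly as many
        # frames as one matting batch (the frames gain a length-1 time axis
        # via unsqueeze(1) in train()).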
        self.dataloader_seg_image = DataLoader(
            dataset=self.dataset_seg_image,
            batch_size=self.args.batch_size_per_gpu * self.args.seq_length_lr,
            num_workers=self.args.num_workers,
            sampler=self.datasampler_seg_image,
            pin_memory=True)

        self.log('Initializing video segmentation datasets')
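        # TrainFrameSampler(speed=[1]) pins the frame sampler to native playback
        # speed for this real-video dataset; the unconstrained sampler used for
        # the matting datasets presumably varies speed as a temporal augmentation.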
        self.dataset_seg_video = YouTubeVISDataset(
            videodir=DATA_PATHS['youtubevis']['videodir'],
            annfile=DATA_PATHS['youtubevis']['annfile'],
            size=self.args.resolution_lr,
            seq_length=self.args.seq_length_lr,
            seq_sampler=TrainFrameSampler(speed=[1]),
            transform=YouTubeVISAugmentation(size_lr))
        self.datasampler_seg_video = DistributedSampler(
            dataset=self.dataset_seg_video,
            rank=self.rank,
            num_replicas=self.world_size,
            shuffle=True)
        self.dataloader_seg_video = DataLoader(
            dataset=self.dataset_seg_video,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            sampler=self.datasampler_seg_video,
            pin_memory=True)

    def init_model(self):
        self.log('Initializing model')
        self.model = MattingNetwork(self.args.model_variant, pretrained_backbone=True).to(self.rank)

        if self.args.checkpoint:
            self.log(f'Restoring from checkpoint: {self.args.checkpoint}')
            self.log(self.model.load_state_dict(
                torch.load(self.args.checkpoint, map_location=f'cuda:{self.rank}')))
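
        # SyncBatchNorm keeps batch-norm statistics consistent across GPUs.
        # find_unused_parameters=True is needed because the matting and
        # segmentation passes each exercise only a subset of the model's
        # parameters, and DDP must tolerate parameters with no gradient.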
        self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
        self.model_ddp = DDP(self.model, device_ids=[self.rank], broadcast_buffers=False, find_unused_parameters=True)

        self.optimizer = Adam([
            {'params': self.model.backbone.parameters(), 'lr': self.args.learning_rate_backbone},
            {'params': self.model.aspp.parameters(), 'lr': self.args.learning_rate_aspp},
            {'params': self.model.decoder.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.project_mat.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.project_seg.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.refiner.parameters(), 'lr': self.args.learning_rate_refiner},
        ])
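
        # GradScaler pairs with the autocast() contexts below: it scales the
        # loss before backward to avoid fp16 gradient underflow, then unscales
        # before the optimizer step.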
        self.scaler = GradScaler()

    def init_writer(self):
        if self.rank == 0:
            self.log('Initializing writer')
            self.writer = SummaryWriter(self.args.log_dir)

    def train(self):
        for epoch in range(self.args.epoch_start, self.args.epoch_end):
            self.epoch = epoch
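            # Deriving the global step from the epoch keeps the step counter
            # increasing across stages that resume with a larger --epoch-start.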
            self.step = epoch * len(self.dataloader_lr_train)

            if not self.args.disable_validation:
                self.validate()

            self.log(f'Training epoch: {epoch}')
            for true_fgr, true_pha, true_bgr in tqdm(self.dataloader_lr_train, disable=self.args.disable_progress_bar, dynamic_ncols=True):
                # Low resolution pass
                self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=1, tag='lr')

                # High resolution pass
                if self.args.train_hr:
                    true_fgr, true_pha, true_bgr = self.load_next_mat_hr_sample()
                    self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=self.args.downsample_ratio, tag='hr')

                # Segmentation pass
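                # Alternate the auxiliary segmentation objective between real
                # video (even steps) and still images (odd steps).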
                if self.step % 2 == 0:
                    true_img, true_seg = self.load_next_seg_video_sample()
                    self.train_seg(true_img, true_seg, log_label='seg_video')
                else:
                    true_img, true_seg = self.load_next_seg_image_sample()
                    self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image')

                if self.step % self.args.checkpoint_save_interval == 0:
                    self.save()

                self.step += 1
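
    # downsample_ratio controls how much the network internally downsamples its
    # input: 1 runs the low-resolution pass at full size, while the default 0.25
    # has the high-resolution pass predict at quarter scale and rely on the
    # refiner to restore full-resolution detail.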
    def train_mat(self, true_fgr, true_pha, true_bgr, downsample_ratio, tag):
        true_fgr = true_fgr.to(self.rank, non_blocking=True)
        true_pha = true_pha.to(self.rank, non_blocking=True)
        true_bgr = true_bgr.to(self.rank, non_blocking=True)
        true_fgr, true_pha, true_bgr = self.random_crop(true_fgr, true_pha, true_bgr)
        true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)

        with autocast(enabled=not self.args.disable_mixed_precision):
            pred_fgr, pred_pha = self.model_ddp(true_src, downsample_ratio=downsample_ratio)[:2]
            loss = matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)

        self.scaler.scale(loss['total']).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if self.rank == 0 and self.step % self.args.log_train_loss_interval == 0:
            for loss_name, loss_value in loss.items():
                self.writer.add_scalar(f'train_{tag}_{loss_name}', loss_value, self.step)

        if self.rank == 0 and self.step % self.args.log_train_images_interval == 0:
            self.writer.add_image(f'train_{tag}_pred_fgr', make_grid(pred_fgr.flatten(0, 1), nrow=pred_fgr.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_pred_pha', make_grid(pred_pha.flatten(0, 1), nrow=pred_pha.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_fgr', make_grid(true_fgr.flatten(0, 1), nrow=true_fgr.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_pha', make_grid(true_pha.flatten(0, 1), nrow=true_pha.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_src', make_grid(true_src.flatten(0, 1), nrow=true_src.size(1)), self.step)

    def train_seg(self, true_img, true_seg, log_label):
        true_img = true_img.to(self.rank, non_blocking=True)
        true_seg = true_seg.to(self.rank, non_blocking=True)
        true_img, true_seg = self.random_crop(true_img, true_seg)

        with autocast(enabled=not self.args.disable_mixed_precision):
            pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0]
            loss = segmentation_loss(pred_seg, true_seg)

        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()
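
        # (self.step - self.step % 2) rounds the step down to the nearest even
        # number, so the video (even-step) and image (odd-step) branches hit
        # the logging intervals at the same cadence.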
        if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_loss_interval == 0:
            self.writer.add_scalar(f'{log_label}_loss', loss, self.step)

        if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_images_interval == 0:
            self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step)
            self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
            self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
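
    # The three helpers below expose the side datasets as endless streams: the
    # first call (iterator attribute not yet created) or an exhausted iterator
    # raises, at which point the sampler epoch is bumped so DistributedSampler
    # reshuffles, and a fresh iterator is started.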
    def load_next_mat_hr_sample(self):
        try:
            sample = next(self.dataiterator_mat_hr)
        except (AttributeError, StopIteration):
            self.datasampler_hr_train.set_epoch(self.datasampler_hr_train.epoch + 1)
            self.dataiterator_mat_hr = iter(self.dataloader_hr_train)
            sample = next(self.dataiterator_mat_hr)
        return sample

    def load_next_seg_video_sample(self):
        try:
            sample = next(self.dataiterator_seg_video)
        except (AttributeError, StopIteration):
            self.datasampler_seg_video.set_epoch(self.datasampler_seg_video.epoch + 1)
            self.dataiterator_seg_video = iter(self.dataloader_seg_video)
            sample = next(self.dataiterator_seg_video)
        return sample

    def load_next_seg_image_sample(self):
        try:
            sample = next(self.dataiterator_seg_image)
        except (AttributeError, StopIteration):
            self.datasampler_seg_image.set_epoch(self.datasampler_seg_image.epoch + 1)
            self.dataiterator_seg_image = iter(self.dataloader_seg_image)
            sample = next(self.dataiterator_seg_image)
        return sample
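
    # Validation runs on rank 0 only, through the raw module (self.model rather
    # than the DDP wrapper), so the idle ranks wait at the barrier without
    # joining any collective ops.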
    def validate(self):
        if self.rank == 0:
            self.log(f'Validating at the start of epoch: {self.epoch}')
            self.model_ddp.eval()
            total_loss, total_count = 0, 0
            with torch.no_grad():
                with autocast(enabled=not self.args.disable_mixed_precision):
                    for true_fgr, true_pha, true_bgr in tqdm(self.dataloader_valid, disable=self.args.disable_progress_bar, dynamic_ncols=True):
                        true_fgr = true_fgr.to(self.rank, non_blocking=True)
                        true_pha = true_pha.to(self.rank, non_blocking=True)
                        true_bgr = true_bgr.to(self.rank, non_blocking=True)
                        true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)
                        batch_size = true_src.size(0)
                        pred_fgr, pred_pha = self.model(true_src)[:2]
                        total_loss += matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)['total'].item() * batch_size
                        total_count += batch_size
            avg_loss = total_loss / total_count
            self.log(f'Validation set average loss: {avg_loss}')
            self.writer.add_scalar('valid_loss', avg_loss, self.step)
            self.model_ddp.train()
        dist.barrier()
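
    # Random-crop augmentation shared by the matting and segmentation passes:
    # pick a random target size in [H/2, H) x [W/2, W), resize every input to a
    # square of the larger side, then take the same center crop across the
    # whole batch and sequence so frames stay spatially consistent over time.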
    def random_crop(self, *imgs):
        h, w = imgs[0].shape[-2:]
        w = random.choice(range(w // 2, w))
        h = random.choice(range(h // 2, h))
        results = []
        for img in imgs:
            B, T = img.shape[:2]
            img = img.flatten(0, 1)
            img = F.interpolate(img, (max(h, w), max(h, w)), mode='bilinear', align_corners=False)
            img = center_crop(img, (h, w))
            img = img.reshape(B, T, *img.shape[1:])
            results.append(img)
        return results

    def save(self):
        if self.rank == 0:
            os.makedirs(self.args.checkpoint_dir, exist_ok=True)
            torch.save(self.model.state_dict(), os.path.join(self.args.checkpoint_dir, f'epoch-{self.epoch}.pth'))
            self.log('Model saved')
        dist.barrier()

    def cleanup(self):
        dist.destroy_process_group()

    def log(self, msg):
        print(f'[GPU{self.rank}] {msg}')
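
# mp.spawn launches one Trainer process per visible GPU. It prepends the process
# rank to `args`, so each worker is constructed as Trainer(rank, world_size) and
# runs the whole pipeline from its __init__.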
if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(
        Trainer,
        nprocs=world_size,
        args=(world_size,),
        join=True)