facedet / scripts /train.py
cledouxluma's picture
Upload scripts/train.py with huggingface_hub
242cb85 verified
#!/usr/bin/env python3
"""
SCRFD Training Script — Full training pipeline with:
- Multi-GPU support via DDP
- Step LR scheduling (MultiStepLR) with linear warmup; cosine decay is not implemented in this script
- Gradient clipping, mixed precision
- Checkpoint saving & resuming
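- WiderFace evaluation hooks (not wired up in this script yet; --eval-freq is parsed but unused)
- Trackio experiment tracking (not wired up in this script yet)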
Training recipe (from SCRFD paper):
- SGD: lr=0.01, momentum=0.9, weight_decay=5e-4
- Warmup: 3 epochs linear from 1e-5
- LR decay: ×0.1 at epoch 440, 544
- Total epochs: 640 (from scratch)
- Batch: 8 per GPU × 4 GPUs
- Input: 640×640 random crops with scale [0.3, 2.0]
Usage:
# Single GPU
python scripts/train.py --config configs/scrfd_34g.yaml
# Multi-GPU
torchrun --nproc_per_node=4 scripts/train.py --config configs/scrfd_34g.yaml
"""
import os
import sys
import argparse
import time
import math
import json
import yaml
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
# Add project root to path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from models.detector import build_detector
from data.dataloader import build_train_loader, build_val_loader
def parse_args():
    """Parse the command-line options of the training script.

    Options are read from ``sys.argv``.

    ``--amp`` and ``--enable-robustness`` are on by default. A ``store_true``
    flag whose default is already ``True`` can never be switched off, so each
    of them is paired with a ``--no-...`` switch that writes ``False`` to the
    same destination. Existing invocations keep working unchanged.

    Returns:
        argparse.Namespace: the parsed options (dashes in option names become
        underscores, e.g. ``--lr-steps`` -> ``lr_steps``).
    """
    parser = argparse.ArgumentParser(description='Train SCRFD Face Detector')

    # Paths and model selection
    parser.add_argument('--config', type=str, default='configs/scrfd_34g.yaml',
                        help='Path to config file')
    parser.add_argument('--data-root', type=str, default='data/wider_face',
                        help='Path to WiderFace dataset root')
    parser.add_argument('--output-dir', type=str, default='checkpoints',
                        help='Output directory for checkpoints')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')
    parser.add_argument('--model', type=str, default='scrfd_34g',
                        choices=['scrfd_34g', 'scrfd_10g', 'scrfd_2.5g', 'scrfd_0.5g'],
                        help='Model variant')

    # Optimisation schedule
    parser.add_argument('--epochs', type=int, default=640)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--warmup-epochs', type=int, default=3)
    parser.add_argument('--lr-steps', nargs='+', type=int, default=[440, 544])
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--momentum', type=float, default=0.9)

    # Data / augmentation
    parser.add_argument('--input-size', type=int, default=640)
    parser.add_argument('--use-landmarks', action='store_true')
    parser.add_argument('--enable-robustness', dest='enable_robustness',
                        action='store_true', default=True)
    parser.add_argument('--no-enable-robustness', dest='enable_robustness',
                        action='store_false',
                        help='Turn off --enable-robustness (on by default)')

    # Mixed precision
    parser.add_argument('--amp', dest='amp', action='store_true', default=True,
                        help='Use automatic mixed precision')
    parser.add_argument('--no-amp', dest='amp', action='store_false',
                        help='Disable automatic mixed precision (on by default)')

    # Training-loop housekeeping
    parser.add_argument('--grad-clip', type=float, default=35.0)
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--save-freq', type=int, default=20)
    parser.add_argument('--log-freq', type=int, default=50)
    parser.add_argument('--eval-freq', type=int, default=50)
    parser.add_argument('--local_rank', type=int, default=0)
    return parser.parse_args()
def setup_distributed():
    """Join the NCCL process group when the launcher exported ``RANK``.

    Returns:
        tuple: ``(distributed, rank, world_size, local_rank)``. When ``RANK``
        is absent from the environment nothing is initialised and
        ``(False, 0, 1, 0)`` is returned.
    """
    env = os.environ
    if 'RANK' not in env:
        # Plain single-process launch: no process group, rank 0 of 1.
        return False, 0, 1, 0

    rank, world_size, local_rank = (
        int(env[name]) for name in ('RANK', 'WORLD_SIZE', 'LOCAL_RANK')
    )
    dist.init_process_group('nccl')
    # Make this process's default CUDA device the one matching LOCAL_RANK.
    torch.cuda.set_device(local_rank)
    return True, rank, world_size, local_rank
def build_optimizer(model, lr, momentum, weight_decay):
    """Create an SGD optimizer with two parameter groups.

    Trainable parameters whose name contains ``'bn'``, ``'gn'`` or ``'bias'``
    go into a group with zero weight decay; every other trainable parameter
    goes into a group using ``weight_decay``. Parameters with
    ``requires_grad=False`` are left out entirely.

    Args:
        model: module providing ``named_parameters()``.
        lr: learning rate applied to both groups.
        momentum: SGD momentum.
        weight_decay: decay applied to the first (decayed) group only.

    Returns:
        torch.optim.SGD: group 0 is the decayed set, group 1 the exempt set.
    """
    exempt_markers = ('bn', 'gn', 'bias')
    # buckets[True] -> exempt from decay, buckets[False] -> decayed.
    buckets = {True: [], False: []}
    for param_name, tensor in model.named_parameters():
        if tensor.requires_grad:
            is_exempt = any(marker in param_name for marker in exempt_markers)
            buckets[is_exempt].append(tensor)

    groups = [
        {'params': buckets[False], 'weight_decay': weight_decay},
        {'params': buckets[True], 'weight_decay': 0.0},
    ]
    return optim.SGD(groups, lr=lr, momentum=momentum)
def warmup_lr(optimizer, epoch, step, steps_per_epoch, warmup_epochs, base_lr):
    """Set the learning rate for one step of the linear warmup ramp.

    The ramp starts at 1e-5 on global step 0 and rises linearly towards
    ``base_lr`` over ``warmup_epochs * steps_per_epoch`` steps. Once the
    global step is no longer inside the warmup window the optimizer is left
    untouched.

    Args:
        optimizer: object exposing ``param_groups`` (list of dicts with 'lr').
        epoch: zero-based epoch index.
        step: zero-based step index within the epoch.
        steps_per_epoch: number of steps in one epoch.
        warmup_epochs: length of the warmup, in epochs.
        base_lr: learning rate the ramp heads towards.
    """
    total_warmup = warmup_epochs * steps_per_epoch
    global_step = epoch * steps_per_epoch + step
    if not global_step < total_warmup:
        return

    ramped = 1e-5 + (base_lr - 1e-5) * global_step / total_warmup
    for group in optimizer.param_groups:
        group['lr'] = ramped
def train_one_epoch(model, loader, optimizer, scaler, epoch, args, is_main):
    """Run one pass over ``loader`` and return the mean of each loss term.

    For every batch the images and target tensors are moved to CUDA, the
    warmup learning rate is applied while ``epoch < args.warmup_epochs``, and
    one optimizer step is taken, either through the AMP path (autocast plus
    ``scaler``) or in full precision. Gradients are norm-clipped when
    ``args.grad_clip > 0``.

    Args:
        model: called as ``model(images, targets)``; must return a dict with
            'cls_loss', 'reg_loss', 'total_loss' and 'num_pos' entries that
            support ``.item()``.
        loader: iterable of ``(images, targets)`` batches; ``targets`` is a
            list of dicts of tensors.
        optimizer: optimizer being stepped.
        scaler: gradient scaler, used only when ``args.amp`` is true.
        epoch: zero-based epoch index.
        args: parsed options (warmup_epochs, lr, amp, grad_clip, log_freq,
            batch_size are read).
        is_main: whether this process prints progress lines.

    Returns:
        dict: per-key sums divided by the number of batches seen (divisor is
        clamped to at least 1, so an empty loader yields zeros).
    """
    model.train()
    running = dict.fromkeys(('cls_loss', 'reg_loss', 'total_loss', 'num_pos'), 0)
    batches_seen = 0
    t0 = time.time()

    for step, (images, targets) in enumerate(loader):
        images = images.cuda(non_blocking=True)
        targets = [
            {key: tensor.cuda(non_blocking=True) for key, tensor in target.items()}
            for target in targets
        ]

        # The per-step LR override only applies inside the warmup epochs.
        if epoch < args.warmup_epochs:
            warmup_lr(optimizer, epoch, step, len(loader),
                      args.warmup_epochs, args.lr)

        optimizer.zero_grad()

        if not args.amp:
            # Full-precision path.
            losses = model(images, targets)
            losses['total_loss'].backward()
            if args.grad_clip > 0:
                nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()
        else:
            # Mixed-precision path: forward under autocast, scaled backward.
            with autocast():
                losses = model(images, targets)
            scaler.scale(losses['total_loss']).backward()
            if args.grad_clip > 0:
                # Unscale first so the clip threshold applies to true gradients.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()

        for name in running:
            running[name] += losses[name].item()
        batches_seen += 1

        # Progress line every log_freq steps, on the main process only.
        if is_main and step % args.log_freq == 0:
            elapsed = time.time() - t0
            if elapsed > 0:
                fps = (step + 1) * args.batch_size / elapsed
            else:
                fps = 0
            print(f" [Epoch {epoch}][{step}/{len(loader)}] "
                  f"cls={losses['cls_loss'].item():.4f} "
                  f"reg={losses['reg_loss'].item():.4f} "
                  f"total={losses['total_loss'].item():.4f} "
                  f"pos={losses['num_pos'].item():.0f} "
                  f"lr={optimizer.param_groups[0]['lr']:.6f} "
                  f"fps={fps:.1f}")

    divisor = max(batches_seen, 1)
    return {name: total / divisor for name, total in running.items()}
def main():
    """Run the training job: setup, optional resume, epoch loop, checkpoints.

    Steps, in order:
      1. Parse options and initialise distributed training if launched with
         ``RANK`` set.
      2. Build the detector, move it to CUDA and wrap it in DDP when
         distributed.
      3. Build the training loader, SGD optimizer, MultiStepLR scheduler and
         (with ``--amp``) a GradScaler.
      4. Optionally restore model/optimizer/scheduler state from ``--resume``.
      5. Train for the remaining epochs; every ``--save-freq`` epochs rank 0
         writes a checkpoint and, if the epoch's mean total loss is the lowest
         seen at a save point, a ``*_best.pth`` copy.
      6. Rank 0 writes ``*_final.pth`` (weights and config only).

    Notes:
      * The scheduler is stepped once per epoch, warmup epochs included, so
        the ``--lr-steps`` milestones line up with real epoch numbers. The
        warmup itself is applied per step inside ``train_one_epoch``.
      * Checkpoints carry ``scaler_state_dict`` and ``best_loss``. Both are
        read with ``.get`` on resume, so older checkpoints without them still
        load.
      * ``--config`` and ``--eval-freq`` are parsed but not consumed here.
    """
    args = parse_args()
    distributed, rank, world_size, local_rank = setup_distributed()
    is_main = rank == 0

    if is_main:
        os.makedirs(args.output_dir, exist_ok=True)
        print(f"Training {args.model} for {args.epochs} epochs")
        print(f" Distributed: {distributed} (world_size={world_size})")
        print(f" Batch size: {args.batch_size} × {world_size} = {args.batch_size * world_size}")
        print(f" LR: {args.lr}, steps: {args.lr_steps}")
        print(f" Input size: {args.input_size}")

    # Build model
    model = build_detector(
        args.model,
        use_landmarks=args.use_landmarks,
    ).cuda()
    if is_main:
        num_params = sum(p.numel() for p in model.parameters()) / 1e6
        print(f" Model parameters: {num_params:.2f}M")
    if distributed:
        model = DDP(model, device_ids=[local_rank], find_unused_parameters=False)

    # Build data loaders
    train_loader = build_train_loader(
        args.data_root,
        batch_size=args.batch_size,
        target_size=args.input_size,
        num_workers=args.num_workers,
        use_landmarks=args.use_landmarks,
        enable_robustness=args.enable_robustness,
        distributed=distributed,
        rank=rank,
        world_size=world_size,
    )

    # Optimizer & scheduler
    optimizer = build_optimizer(model, args.lr, args.momentum, args.weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, args.lr_steps, gamma=0.1)
    scaler = GradScaler() if args.amp else None

    # Resume
    start_epoch = 0
    best_loss = float('inf')
    if args.resume:
        # NOTE(security): torch.load unpickles the file; only resume from
        # checkpoints you trust.
        checkpoint = torch.load(args.resume, map_location='cpu')
        bare_model = model.module if distributed else model
        bare_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # Restore the AMP loss scale when the checkpoint has one, so resuming
        # does not restart from the scaler's initial scale.
        scaler_state = checkpoint.get('scaler_state_dict')
        if scaler is not None and scaler_state:
            scaler.load_state_dict(scaler_state)
        # Keep the best-so-far loss, otherwise the first save after resuming
        # would always overwrite the existing best checkpoint.
        best_loss = checkpoint.get('best_loss', best_loss)
        start_epoch = checkpoint['epoch'] + 1
        if is_main:
            print(f" Resumed from epoch {start_epoch}")

    # Training loop
    for epoch in range(start_epoch, args.epochs):
        if distributed:
            # Gives the distributed sampler a new epoch number each epoch.
            train_loader.sampler.set_epoch(epoch)

        avg_losses = train_one_epoch(model, train_loader, optimizer, scaler,
                                     epoch, args, is_main)

        # Step every epoch, warmup included. Skipping the warmup epochs would
        # leave the scheduler's counter behind by `warmup_epochs`, so every
        # decay would land that many epochs later than --lr-steps says.
        # Outside a milestone MultiStepLR keeps the current LR, so this does
        # not disturb the per-step warmup values.
        scheduler.step()

        # Logging
        if is_main:
            print(f"Epoch {epoch} avg: cls={avg_losses['cls_loss']:.4f} "
                  f"reg={avg_losses['reg_loss']:.4f} "
                  f"total={avg_losses['total_loss']:.4f}")

        # Save checkpoint
        if is_main and (epoch + 1) % args.save_freq == 0:
            is_best = avg_losses['total_loss'] < best_loss
            if is_best:
                best_loss = avg_losses['total_loss']
            state = {
                'epoch': epoch,
                'model_state_dict': (model.module if distributed else model).state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict() if scaler is not None else None,
                'best_loss': best_loss,
                'avg_losses': avg_losses,
                'config': vars(args),
            }
            path = os.path.join(args.output_dir, f'{args.model}_epoch{epoch}.pth')
            torch.save(state, path)
            print(f" Saved checkpoint: {path}")
            if is_best:
                best_path = os.path.join(args.output_dir, f'{args.model}_best.pth')
                torch.save(state, best_path)
                print(f" New best model: {best_path}")

    # Save final model (weights and config only; not usable with --resume,
    # which also needs optimizer and scheduler state).
    if is_main:
        final_state = {
            'epoch': args.epochs - 1,
            'model_state_dict': (model.module if distributed else model).state_dict(),
            'config': vars(args),
        }
        torch.save(final_state, os.path.join(args.output_dir, f'{args.model}_final.pth'))
        print("Training complete!")

    if distributed:
        dist.destroy_process_group()
# Entry point: start training only when this file is executed as a script
# (python or torchrun), not when it is imported.
if __name__ == '__main__':
    main()