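"""Deepfake classifier training pipeline (training/pipelines/train_classifier.py).

Trains a binary real-vs-fake classifier on pre-extracted crops: builds the augmentation
pipelines, instantiates the network from training.zoo.classifiers and a weighted loss from
the config file, then alternates train_epoch() and evaluate_val(), checkpointing the latest
and best weights to --output-dir and logging metrics to TensorBoard.
"""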
import argparse
import json
import os
from collections import defaultdict
from sklearn.metrics import log_loss
from torch import topk
import sys

sys.path.append('..')  # so `from training import ...` resolves when the script is run from inside the training/ directory
from training import losses
from training.datasets.classifier_dataset import DeepFakeClassifierDataset
from training.losses import WeightedLosses
from training.tools.config import load_config
from training.tools.utils import create_optimizer, AverageMeter
from training.transforms.albu import IsotropicResize
from training.zoo import classifiers
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
import cv2
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)
import numpy as np
from albumentations import Compose, RandomBrightnessContrast, \
HorizontalFlip, FancyPCA, HueSaturationValue, OneOf, ToGray, \
ShiftScaleRotate, ImageCompression, PadIfNeeded, GaussNoise, GaussianBlur
from apex.parallel import DistributedDataParallel, convert_syncbn_model
from tensorboardX import SummaryWriter
from apex import amp
import torch
from torch.backends import cudnn
from torch.nn import DataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.distributed as dist
torch.backends.cudnn.benchmark = True
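
# Training-time augmentations: random JPEG compression, light Gaussian noise/blur, horizontal flips,
# one of three isotropic resizes to `size`, constant-border padding to a square, one of
# brightness/contrast, FancyPCA or hue/saturation jitter, occasional grayscale conversion,
# and a mild shift/scale/rotate.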
def create_train_transforms(size=300):
return Compose([
ImageCompression(quality_lower=60, quality_upper=100, p=0.5),
GaussNoise(p=0.1),
GaussianBlur(blur_limit=3, p=0.05),
HorizontalFlip(),
OneOf([
IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR),
IsotropicResize(max_side=size, interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR),
], p=1),
PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
OneOf([RandomBrightnessContrast(), FancyPCA(), HueSaturationValue()], p=0.7),
ToGray(p=0.2),
ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=10, border_mode=cv2.BORDER_CONSTANT, p=0.5),
]
)
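
# Validation transforms are deterministic: an isotropic resize to `size` followed by constant-border padding.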
def create_val_transforms(size=300):
return Compose([
IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
])
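
# Entry point: parses the CLI, loads the config, builds the model, losses, optimizer and datasets,
# optionally resumes from a checkpoint, then alternates train_epoch() and evaluate_val(),
# saving snapshots to --output-dir and TensorBoard summaries to --logdir.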
def main():
    parser = argparse.ArgumentParser("DeepFake classifier training pipeline")
arg = parser.add_argument
arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
arg('--workers', type=int, default=6, help='number of cpu threads to use')
arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
arg('--output-dir', type=str, default='weights/')
arg('--resume', type=str, default='')
arg('--fold', type=int, default=0)
arg('--prefix', type=str, default='classifier_')
arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
arg('--folds-csv', type=str, default='folds.csv')
arg('--crops-dir', type=str, default='crops')
arg('--label-smoothing', type=float, default=0.01)
arg('--logdir', type=str, default='logs')
arg('--zero-score', action='store_true', default=False)
arg('--from-zero', action='store_true', default=False)
arg('--distributed', action='store_true', default=False)
arg('--freeze-epochs', type=int, default=0)
arg("--local_rank", default=0, type=int)
arg("--seed", default=777, type=int)
arg("--padding-part", default=3, type=int)
arg("--opt-level", default='O1', type=str)
arg("--test_every", type=int, default=1)
arg("--no-oversample", action="store_true")
arg("--no-hardcore", action="store_true")
arg("--only-changed-frames", action="store_true")
args = parser.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
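    # Single-process DataParallel by default; --distributed switches to one NCCL process per GPU.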
if args.distributed:
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
else:
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
cudnn.benchmark = True
conf = load_config(args.config)
model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])
model = model.cuda()
if args.distributed:
model = convert_syncbn_model(model)
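    # Combine the configured losses into a single weighted criterion; when "ohem_samples" is set the
    # reduction is per-sample so train_epoch() can keep only the hardest examples of each class.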
ohem = conf.get("ohem_samples", None)
reduction = "mean"
if ohem:
reduction = "none"
loss_fn = []
weights = []
for loss_name, weight in conf["losses"].items():
loss_fn.append(losses.__dict__[loss_name](reduction=reduction).cuda())
weights.append(weight)
loss = WeightedLosses(loss_fn, weights)
loss_functions = {"classifier_loss": loss}
optimizer, scheduler = create_optimizer(conf['optimizer'], model)
bce_best = 100
start_epoch = 0
batch_size = conf['optimizer']['batch_size']
data_train = DeepFakeClassifierDataset(mode="train",
oversample_real=not args.no_oversample,
fold=args.fold,
padding_part=args.padding_part,
hardcore=not args.no_hardcore,
crops_dir=args.crops_dir,
data_path=args.data_dir,
label_smoothing=args.label_smoothing,
folds_csv=args.folds_csv,
transforms=create_train_transforms(conf["size"]),
normalize=conf.get("normalize", None))
data_val = DeepFakeClassifierDataset(mode="val",
fold=args.fold,
padding_part=args.padding_part,
crops_dir=args.crops_dir,
data_path=args.data_dir,
folds_csv=args.folds_csv,
transforms=create_val_transforms(conf["size"]),
normalize=conf.get("normalize", None))
val_data_loader = DataLoader(data_val, batch_size=batch_size * 2, num_workers=args.workers, shuffle=False,
pin_memory=False)
os.makedirs(args.logdir, exist_ok=True)
summary_writer = SummaryWriter(args.logdir + '/' + conf.get("prefix", args.prefix) + conf['encoder'] + "_" + str(args.fold))
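    # Optional resume: checkpoints are saved from a (Distributed)DataParallel wrapper, so the
    # "module." prefix is stripped before loading; --from-zero restarts the epoch counter and
    # --zero-score discards the stored best score.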
if args.resume:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume, map_location='cpu')
state_dict = checkpoint['state_dict']
state_dict = {k[7:]: w for k, w in state_dict.items()}
model.load_state_dict(state_dict, strict=False)
if not args.from_zero:
start_epoch = checkpoint['epoch']
if not args.zero_score:
                    bce_best = checkpoint.get('bce_best', bce_best)  # keep the initial value if the checkpoint lacks it
print("=> loaded checkpoint '{}' (epoch {}, bce_best {})"
.format(args.resume, checkpoint['epoch'], checkpoint['bce_best']))
else:
print("=> no checkpoint found at '{}'".format(args.resume))
if args.from_zero:
start_epoch = 0
current_epoch = start_epoch
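    # Mixed-precision training via NVIDIA apex with dynamic loss scaling; the opt level comes from --opt-level.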
if conf['fp16']:
model, optimizer = amp.initialize(model, optimizer,
opt_level=args.opt_level,
loss_scale='dynamic')
snapshot_name = "{}{}_{}_{}".format(conf.get("prefix", args.prefix), conf['network'], conf['encoder'], args.fold)
if args.distributed:
model = DistributedDataParallel(model, delay_allreduce=True)
else:
model = DataParallel(model).cuda()
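    # The validation set is reset once with a fixed seed, presumably so every epoch is scored on the same crops.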
data_val.reset(1, args.seed)
max_epochs = conf['optimizer']['schedule']['epochs']
for epoch in range(start_epoch, max_epochs):
data_train.reset(epoch, args.seed)
train_sampler = None
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(data_train)
train_sampler.set_epoch(epoch)
if epoch < args.freeze_epochs:
print("Freezing encoder!!!")
model.module.encoder.eval()
for p in model.module.encoder.parameters():
p.requires_grad = False
else:
model.module.encoder.train()
for p in model.module.encoder.parameters():
p.requires_grad = True
train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=args.workers,
shuffle=train_sampler is None, sampler=train_sampler, pin_memory=False,
drop_last=True)
train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
args.local_rank, args.only_changed_frames)
model = model.eval()
if args.local_rank == 0:
            torch.save({
                'epoch': current_epoch + 1,
                'state_dict': model.state_dict(),
                'bce_best': bce_best,
            }, os.path.join(args.output_dir, snapshot_name + "_last"))
            torch.save({
                'epoch': current_epoch + 1,
                'state_dict': model.state_dict(),
                'bce_best': bce_best,
            }, os.path.join(args.output_dir, snapshot_name + "_{}".format(current_epoch)))
if (epoch + 1) % args.test_every == 0:
bce_best = evaluate_val(args, val_data_loader, bce_best, model,
snapshot_name=snapshot_name,
current_epoch=current_epoch,
summary_writer=summary_writer)
current_epoch += 1
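
# Runs validation, logs BCE to TensorBoard and, whenever the averaged fake/real log-loss improves,
# saves a "_best_dice" snapshot together with the per-video predictions for the current fold.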
def evaluate_val(args, data_val, bce_best, model, snapshot_name, current_epoch, summary_writer):
print("Test phase")
model = model.eval()
bce, probs, targets = validate(model, data_loader=data_val)
if args.local_rank == 0:
summary_writer.add_scalar('val/bce', float(bce), global_step=current_epoch)
if bce < bce_best:
print("Epoch {} improved from {} to {}".format(current_epoch, bce_best, bce))
if args.output_dir is not None:
torch.save({
'epoch': current_epoch + 1,
'state_dict': model.state_dict(),
'bce_best': bce,
                }, os.path.join(args.output_dir, snapshot_name + "_best_dice"))
bce_best = bce
with open("predictions_{}.json".format(args.fold), "w") as f:
json.dump({"probs": probs, "targets": targets}, f)
torch.save({
'epoch': current_epoch + 1,
'state_dict': model.state_dict(),
'bce_best': bce_best,
        }, os.path.join(args.output_dir, snapshot_name + "_last"))
print("Epoch: {} bce: {}, bce_best: {}".format(current_epoch, bce, bce_best))
return bce_best
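
# Aggregates per-frame sigmoid outputs into a per-video mean probability, then reports log-loss
# separately for fake (mean target > 0.1) and real (mean target < 0.1) videos and returns their average.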
def validate(net, data_loader, prefix=""):
probs = defaultdict(list)
targets = defaultdict(list)
with torch.no_grad():
for sample in tqdm(data_loader):
imgs = sample["image"].cuda()
img_names = sample["img_name"]
labels = sample["labels"].cuda().float()
out = net(imgs)
labels = labels.cpu().numpy()
preds = torch.sigmoid(out).cpu().numpy()
for i in range(out.shape[0]):
video, img_id = img_names[i].split("/")
probs[video].append(preds[i].tolist())
targets[video].append(labels[i].tolist())
data_x = []
data_y = []
for vid, score in probs.items():
score = np.array(score)
lbl = targets[vid]
score = np.mean(score)
lbl = np.mean(lbl)
data_x.append(score)
data_y.append(lbl)
y = np.array(data_y)
x = np.array(data_x)
fake_idx = y > 0.1
real_idx = y < 0.1
fake_loss = log_loss(y[fake_idx], x[fake_idx], labels=[0, 1])
real_loss = log_loss(y[real_idx], x[real_idx], labels=[0, 1])
print("{}fake_loss".format(prefix), fake_loss)
print("{}real_loss".format(prefix), real_loss)
return (fake_loss + real_loss) / 2, probs, targets
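
# One training epoch: the classifier loss is computed separately on fake and real samples and averaged;
# with "ohem_samples" set, only the top-k largest per-sample losses of each class contribute.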
def train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
local_rank, only_valid):
losses = AverageMeter()
fake_losses = AverageMeter()
real_losses = AverageMeter()
max_iters = conf["batches_per_epoch"]
print("training epoch {}".format(current_epoch))
model.train()
pbar = tqdm(enumerate(train_data_loader), total=max_iters, desc="Epoch {}".format(current_epoch), ncols=0)
if conf["optimizer"]["schedule"]["mode"] == "epoch":
scheduler.step(current_epoch)
for i, sample in pbar:
imgs = sample["image"].cuda()
labels = sample["labels"].cuda().float()
out_labels = model(imgs)
if only_valid:
valid_idx = sample["valid"].cuda().float() > 0
out_labels = out_labels[valid_idx]
labels = labels[valid_idx]
if labels.size(0) == 0:
continue
fake_loss = 0
real_loss = 0
fake_idx = labels > 0.5
real_idx = labels <= 0.5
ohem = conf.get("ohem_samples", None)
if torch.sum(fake_idx * 1) > 0:
fake_loss = loss_functions["classifier_loss"](out_labels[fake_idx], labels[fake_idx])
if torch.sum(real_idx * 1) > 0:
real_loss = loss_functions["classifier_loss"](out_labels[real_idx], labels[real_idx])
if ohem:
fake_loss = topk(fake_loss, k=min(ohem, fake_loss.size(0)), sorted=False)[0].mean()
real_loss = topk(real_loss, k=min(ohem, real_loss.size(0)), sorted=False)[0].mean()
loss = (fake_loss + real_loss) / 2
losses.update(loss.item(), imgs.size(0))
fake_losses.update(0 if fake_loss == 0 else fake_loss.item(), imgs.size(0))
real_losses.update(0 if real_loss == 0 else real_loss.item(), imgs.size(0))
optimizer.zero_grad()
pbar.set_postfix({"lr": float(scheduler.get_lr()[-1]), "epoch": current_epoch, "loss": losses.avg,
"fake_loss": fake_losses.avg, "real_loss": real_losses.avg})
        if conf['fp16']:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
optimizer.step()
torch.cuda.synchronize()
if conf["optimizer"]["schedule"]["mode"] in ("step", "poly"):
scheduler.step(i + current_epoch * max_iters)
if i == max_iters - 1:
break
pbar.close()
if local_rank == 0:
for idx, param_group in enumerate(optimizer.param_groups):
lr = param_group['lr']
summary_writer.add_scalar('group{}/lr'.format(idx), float(lr), global_step=current_epoch)
summary_writer.add_scalar('train/loss', float(losses.avg), global_step=current_epoch)
if __name__ == '__main__':
main()