import os
import numpy as np
import time
import sys
import argparse
import errno
from collections import OrderedDict
import tensorboardX
from tqdm import tqdm
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

from lib.utils.tools import *
from lib.utils.learning import *
from lib.model.loss import *
from lib.data.dataset_action import NTURGBD, NTURGBD1Shot
from lib.model.model_action import ActionNet

from lib.model.loss_supcon import SupConLoss
from pytorch_metric_learning import samplers
|
|
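# Fix random seeds for reproducibility.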
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

|
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/pretrain.yaml", help="Path to the config file.")
    parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='checkpoint directory')
    parser.add_argument('-p', '--pretrained', default='checkpoint', type=str, metavar='PATH', help='pretrained checkpoint directory')
    parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', help='checkpoint to resume (file name)')
    parser.add_argument('-e', '--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
    parser.add_argument('-freq', '--print_freq', default=100, type=int, help='print frequency (in batches)')
    parser.add_argument('-ms', '--selection', default='best_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to finetune (file name)')
    opts = parser.parse_args()
    return opts
|
|
|
def extract_feats(dataloader_x, model):
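    """Run the model over a dataloader without tracking gradients and return
    the concatenated features and ground-truth labels."""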
    all_feats = []
    all_gts = []
    with torch.no_grad():
        for idx, (batch_input, batch_gt) in tqdm(enumerate(dataloader_x)):
            if torch.cuda.is_available():
                batch_input = batch_input.cuda()
            feat = model(batch_input)
            all_feats.append(feat)
            all_gts.append(batch_gt)
    all_feats = torch.cat(all_feats)
    all_gts = torch.cat(all_gts)
    return all_feats, all_gts
|
|
|
def validate(anchor_loader, test_loader, model):
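    """One-shot evaluation: embed the anchor (exemplar) and test clips, then
    assign each test clip the label of its most cosine-similar anchor."""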
    model.eval()
    train_feats, train_labels = extract_feats(anchor_loader, model)
    test_feats, test_labels = extract_feats(test_loader, model)
    train_feats = train_feats.unsqueeze(1)   # (M, 1, C)
    test_feats = test_feats.unsqueeze(0)     # (1, N, C)
    dis = F.cosine_similarity(train_feats, test_feats, dim=-1)  # (M, N) similarity matrix
    pred = train_labels[torch.argmax(dis, dim=0).cpu()]  # keep indices on the same (CPU) device as the labels
    assert len(pred) == len(test_labels)
    acc = (pred == test_labels).float().mean().item()
    return acc
|
|
|
def train_with_config(args, opts):
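    """Fine-tune the skeleton backbone with a supervised contrastive (SupCon)
    loss for one-shot action recognition, evaluating by nearest-anchor matching
    after every epoch."""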
    print(args)
    try:
        os.makedirs(opts.checkpoint)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise RuntimeError('Unable to create checkpoint directory: {}'.format(opts.checkpoint))
    train_writer = tensorboardX.SummaryWriter(os.path.join(opts.checkpoint, "logs"))
    model_backbone = load_backbone(args)
    if args.finetune:
        if opts.resume or opts.evaluate:
            pass
        else:
            chk_filename = os.path.join(opts.pretrained, "best_epoch.bin")
            print('Loading backbone', chk_filename)
            checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
            new_state_dict = OrderedDict()
            for k, v in checkpoint['model_pos'].items():
                name = k[7:]  # strip the 'module.' prefix added by DataParallel
                new_state_dict[name] = v
            model_backbone.load_state_dict(new_state_dict, strict=True)
    if args.partial_train:
        model_backbone = partial_train_layers(model_backbone, args.partial_train)
    model = ActionNet(backbone=model_backbone, dim_rep=args.dim_rep, dropout_ratio=args.dropout_ratio, version=args.model_version, hidden_dim=args.hidden_dim, num_joints=args.num_joints)
    criterion = SupConLoss(temperature=args.temp)
|
|
|
    if torch.cuda.is_available():
        model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
|
|
|
    chk_filename = os.path.join(opts.checkpoint, "latest_epoch.bin")
    if os.path.exists(chk_filename):
        opts.resume = chk_filename
    if opts.resume or opts.evaluate:
        chk_filename = opts.evaluate if opts.evaluate else opts.resume
        print('Loading checkpoint', chk_filename)
        checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['model'], strict=True)

    best_acc = 0
    model_params = 0
    for parameter in model.parameters():
        model_params = model_params + parameter.numel()
    print('INFO: Trainable parameter count:', model_params)
    print('Loading dataset...')
|
|
|
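    # Anchor (exemplar) and test loaders for one-shot evaluation; both use
    # deterministic test-time preprocessing (no random_move, test scale range).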
    anchorloader_params = {
        'batch_size': args.batch_size,
        'shuffle': False,
        'num_workers': 8,
        'pin_memory': True,
        'prefetch_factor': 4,
        'persistent_workers': True
    }

    testloader_params = {
        'batch_size': args.batch_size,
        'shuffle': False,
        'num_workers': 8,
        'pin_memory': True,
        'prefetch_factor': 4,
        'persistent_workers': True
    }

    data_path_1shot = 'data/action/ntu120_hrnet_oneshot.pkl'
    ntu60_1shot_anchor = NTURGBD(data_path=data_path_1shot, data_split='oneshot_train', n_frames=args.clip_len, random_move=False, scale_range=args.scale_range_test)
    ntu60_1shot_test = NTURGBD(data_path=data_path_1shot, data_split='oneshot_val', n_frames=args.clip_len, random_move=False, scale_range=args.scale_range_test)
    anchor_loader = DataLoader(ntu60_1shot_anchor, **anchorloader_params)
    test_loader = DataLoader(ntu60_1shot_test, **testloader_params)
|
|
|
    if not opts.evaluate:
        data_path = 'data/action/ntu120_hrnet.pkl'
        ntu120_1shot_train = NTURGBD1Shot(data_path=data_path, data_split='', n_frames=args.clip_len, random_move=args.random_move, scale_range=args.scale_range_train, check_split=False)
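        # MPerClassSampler draws m = args.n_views samples per class in each batch,
        # so SupConLoss gets same-class positives to contrast against.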
        sampler = samplers.MPerClassSampler(ntu120_1shot_train.labels, m=args.n_views, batch_size=args.batch_size, length_before_new_iter=len(ntu120_1shot_train))
        trainloader_params = {
            'batch_size': args.batch_size,
            'shuffle': False,
            'num_workers': 8,
            'pin_memory': True,
            'prefetch_factor': 4,
            'persistent_workers': True,
            'sampler': sampler
        }
        train_loader = DataLoader(ntu120_1shot_train, **trainloader_params)
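        # Separate learning rates for the pretrained backbone and the action head;
        # only parameters that require gradients are optimised.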
        optimizer = optim.AdamW(
            [{"params": filter(lambda p: p.requires_grad, model.module.backbone.parameters()), "lr": args.lr_backbone},
             {"params": filter(lambda p: p.requires_grad, model.module.head.parameters()), "lr": args.lr_head}],
            lr=args.lr_backbone,
            weight_decay=args.weight_decay
        )
        scheduler = StepLR(optimizer, step_size=1, gamma=args.lr_decay)
|
        st = 0
        print('INFO: Training on {} batches'.format(len(train_loader)))
        if opts.resume:
            st = checkpoint['epoch']
            if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None:
                optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.')
            lr = checkpoint['lr']
            if 'best_acc' in checkpoint and checkpoint['best_acc'] is not None:
                best_acc = checkpoint['best_acc']

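        # Training loop: optimise the SupCon objective, then evaluate one-shot
        # top-1 accuracy and checkpoint after every epoch.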
        for epoch in range(st, args.epochs):
            print('Training epoch %d.' % epoch)
            losses_train = AverageMeter()
            batch_time = AverageMeter()
            data_time = AverageMeter()

            model.train()
            end = time.time()

            for idx, (batch_input, batch_gt) in tqdm(enumerate(train_loader)):
                data_time.update(time.time() - end)
                batch_size = len(batch_input)
                if torch.cuda.is_available():
                    batch_gt = batch_gt.cuda()
                    batch_input = batch_input.cuda()
                feat = model(batch_input)
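                # Reshape the flattened features to (batch_size, n_views, feature_dim),
                # the layout SupConLoss expects.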
                feat = feat.reshape(batch_size, -1, args.hidden_dim)
                optimizer.zero_grad()
                loss_train = criterion(feat, batch_gt)
                losses_train.update(loss_train.item(), batch_size)
                loss_train.backward()
                optimizer.step()
                batch_time.update(time.time() - end)
                end = time.time()
                if (idx + 1) % opts.print_freq == 0:
                    print('Train: [{0}][{1}/{2}]\t'
                          'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                          'loss {loss.val:.3f} ({loss.avg:.3f})\t'.format(
                              epoch, idx + 1, len(train_loader), batch_time=batch_time,
                              data_time=data_time, loss=losses_train))
                    sys.stdout.flush()
|
            test_top1 = validate(anchor_loader, test_loader, model)
            train_writer.add_scalar('train_loss_supcon', losses_train.avg, epoch + 1)
            train_writer.add_scalar('test_top1', test_top1, epoch + 1)
            scheduler.step()

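            # Save the latest state every epoch so interrupted runs can resume.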
            chk_path = os.path.join(opts.checkpoint, 'latest_epoch.bin')
            print('Saving checkpoint to', chk_path)
            torch.save({
                'epoch': epoch + 1,
                'lr': scheduler.get_last_lr(),
                'optimizer': optimizer.state_dict(),
                'model': model.state_dict(),
                'best_acc': best_acc
            }, chk_path)

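            # Keep a separate copy whenever the one-shot top-1 accuracy improves.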
            best_chk_path = os.path.join(opts.checkpoint, 'best_epoch.bin')
            if test_top1 > best_acc:
                best_acc = test_top1
                print("save best checkpoint")
                torch.save({
                    'epoch': epoch + 1,
                    'lr': scheduler.get_last_lr(),
                    'optimizer': optimizer.state_dict(),
                    'model': model.state_dict(),
                    'best_acc': best_acc
                }, best_chk_path)
|
    if opts.evaluate:
        test_top1 = validate(anchor_loader, test_loader, model)
        print('One-shot top-1 accuracy:', test_top1)
|
if __name__ == "__main__":
    opts = parse_args()
    args = get_config(opts.config)
    train_with_config(args, opts)