|
|
|
|
|
|
|
|
|
|
|
|
|
from torch import nn |
|
from torch import optim |
|
from torchvision import transforms |
|
from torch.optim import lr_scheduler |
|
|
|
|
|
|
|
|
|
from generate_c3d_model import generate_model |
|
from train import train_epoch |
|
|
|
|
|
|
|
|
|
from datasets.nv import NV |
|
|
|
|
|
|
|
|
|
from utils import * |
|
from target_transforms import * |
|
|
|
|
|
|
|
|
|
from logger.logger import get_logger |
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build the C3D model, the NV training set and an SGD
    # optimizer, then run `n_epochs` training epochs, saving a checkpoint
    # after each one.

    # The visible imports bind `nn`, `optim` and `lr_scheduler`, but never the
    # top-level `torch` name used below. Import it explicitly instead of
    # depending on one of the star imports to leak it into this namespace.
    import torch

    logger.info("run")

    # Fixed seed so that torch's RNG-driven steps are repeatable across runs.
    torch.manual_seed(1)

    # ---- Run configuration -------------------------------------------------
    arch = 'c3d'
    n_epochs = 35
    n_classes = 26
    sample_size = 112
    ft_portion = "last_layer"
    # NOTE(review): `downsample` and `scales` are not consumed anywhere in the
    # visible part of this script; they are kept in case code outside this
    # view reads them.
    downsample = 2
    scale_step = 0.84089641525

    # Five geometrically decreasing scale factors: 1.0, then each previous
    # value multiplied by `scale_step`.
    scales = [1.0]
    for _ in range(4):
        scales.append(scales[-1] * scale_step)

    # ---- Model and loss ----------------------------------------------------
    model, parameters = generate_model(n_classes, sample_size, ft_portion)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    # ---- Input / target transforms -----------------------------------------
    # The spatial pipeline is intentionally empty in the visible source.
    spatial_transform = transforms.Compose([])
    # NOTE(review): `ToTensor` is used as the *temporal* transform here;
    # confirm this is what the NV dataset expects.
    temporal_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    target_transform = ClassLabel()

    # ---- Optimisation ------------------------------------------------------
    optimizer = optim.SGD(
        parameters,
        lr=0.1,
        momentum=0.9,
        dampening=0.9,
        weight_decay=1e-3,
        nesterov=False,
    )

    # NOTE(review): this scheduler is constructed but never stepped in the
    # visible code; the learning rate is driven by `adjust_learning_rate`
    # inside the epoch loop instead. Confirm whether it can be dropped.
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)

    # ---- Data --------------------------------------------------------------
    training_data = NV(
        './nvGesture_v1.1/nvGesture_v1',
        './annotation_nvGesture_v1/nvall_but_None.json',
        'training',
        spatial_transform=spatial_transform,
        temporal_transform=temporal_transform,
        target_transform=target_transform,
        modality="RGB-D",
    )

    train_loader = torch.utils.data.DataLoader(
        training_data,
        batch_size=80,
        shuffle=True,
        num_workers=12,
        pin_memory=True,
    )

    # ---- Training loop -----------------------------------------------------
    # No validation runs in the visible code, so `best_prec1` is never
    # updated and every checkpoint records 0.
    best_prec1 = 0

    # Epochs are numbered from 1 to `n_epochs` inclusive.
    for i in range(1, n_epochs + 1):
        torch.cuda.empty_cache()

        # `adjust_learning_rate` is presumably provided by one of the star
        # imports at the top of the file -- TODO confirm.
        adjust_learning_rate(optimizer, i)
        train_epoch(i, train_loader, model, criterion, optimizer)

        state = {
            'epoch': i,
            'arch': arch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_prec1': best_prec1,
        }
        # Second argument is passed as False on every epoch; presumably an
        # "is best" flag -- verify against `save_checkpoint`.
        save_checkpoint(state, False)
|
|