# Author: Nadine Rueegg
# Commit 753fd9a: "initial commit with code and data"
# (file-viewer residue: raw / history / blame, 9.63 kB)
# scripts/train.py --workers 12 --checkpoint project22_no3dcgloss_smaldogsilvia_v0 --loss-weight-path barc_loss_weights_no3dcgloss.json --config barc_cfg_train.yaml start --model-file-hg hg_ksp_fromnewanipose_stanext_v0/checkpoint.pth.tar --model-file-3d barc_normflow_pret/checkpoint.pth.tar
import torch
import torch.backends.cudnn
import torch.nn.parallel
from tqdm import tqdm
import os
import json
import pathlib
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../', 'src'))
# from stacked_hourglass.loss import joints_mse_loss
from stacked_hourglass.loss import joints_mse_loss_onKPloc
from stacked_hourglass.utils.evaluation import accuracy, AverageMeter, final_preds, get_preds, get_preds_soft
from stacked_hourglass.utils.transforms import fliplr, flip_back
from stacked_hourglass.utils.visualization import save_input_image_with_keypoints
def do_training_step(model, optimiser, input, target, meta, data_info, target_weight=None):
    """Perform a single optimisation step on one batch.

    The model is run on ``input``, the keypoint-location loss is summed over every
    element of the model output, and the optimiser is stepped once.

    Args:
        model: Network in training mode; its output is iterated over and the last
            element is returned.
        optimiser: Optimiser whose gradients are zeroed and which is stepped once.
        input: Batch of network inputs.
        target: Ground-truth targets, one per example in ``input``.
        meta: Batch metadata forwarded to ``joints_mse_loss_onKPloc``.
        data_info: Unused here; kept for signature compatibility.
        target_weight: Optional per-joint weights forwarded to the loss.

    Returns:
        Tuple ``(final_output, loss_value)``: the last element of the model output
        and the total loss as a Python number.
    """
    assert model.training, 'model must be in training mode.'
    assert len(input) == len(target), 'input and target must contain the same number of examples.'

    # Gradients are forced on even if the caller disabled them globally.
    with torch.enable_grad():
        stack_outputs = model(input)
        # Every model output (presumably one per hourglass stack) is supervised.
        total_loss = sum(
            joints_mse_loss_onKPloc(stack_out, target, meta, target_weight)
            for stack_out in stack_outputs
        )

        optimiser.zero_grad()
        total_loss.backward()
        optimiser.step()

    final_output = stack_outputs[-1]
    return final_output, total_loss.item()
def do_training_epoch(train_loader, model, device, data_info, optimiser, quiet=False, acc_joints=None):
    """Run one training epoch over ``train_loader``.

    Each batch is moved to ``device`` and handed to ``do_training_step``, which does
    the forward pass, loss computation and optimiser update.

    Args:
        train_loader: Iterable of ``(input, target, meta)`` batches, where ``meta``
            holds at least a ``'target_weight'`` tensor. ``len(train_loader)`` is
            used as the progress-bar total.
        model: Network to train; switched to training mode here.
        device: Device the input, target and target weights are moved to.
        data_info: Forwarded unchanged to ``do_training_step``.
        optimiser: Optimiser stepped once per batch.
        quiet: If True, no tqdm progress bar is shown.
        acc_joints: Joint indices forwarded to ``accuracy``.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` as accumulated by ``AverageMeter``
        with batch-size weights. The accuracy is a fraction; the progress bar
        displays it multiplied by 100.
    """
    losses = AverageMeter()
    accuracies = AverageMeter()

    # Put the model in training mode.
    model.train()

    iterable = train_loader
    progress = None
    if not quiet:
        progress = tqdm(iterable, desc='Train', total=len(train_loader), ascii=True, leave=False)
        iterable = progress

    for inputs, target, meta in iterable:
        inputs, target = inputs.to(device), target.to(device, non_blocking=True)
        target_weight = meta['target_weight'].to(device, non_blocking=True)

        output, loss = do_training_step(model, optimiser, inputs, target, meta, data_info, target_weight)

        acc = accuracy(output, target, acc_joints)

        # Record loss and accuracy. ``acc[0]`` is a single-element tensor; converting
        # it to a Python number (as do_validation_epoch does with ``.item()``) keeps
        # the meter, and therefore the returned average, free of tensors.
        batch_size = inputs.size(0)
        losses.update(loss, batch_size)
        accuracies.update(float(acc[0]), batch_size)

        # Show accuracy and loss as part of the progress bar.
        if progress is not None:
            progress.set_postfix_str('Loss: {loss:0.4f}, Acc: {acc:6.2f}'.format(
                loss=losses.avg,
                acc=100 * accuracies.avg
            ))

    return losses.avg, accuracies.avg
def do_validation_step(model, input, target, meta, data_info, target_weight=None, flip=False):
    """Evaluate one batch and return its predicted heatmaps and loss.

    Args:
        model: Network to evaluate. The eval-mode assertion below is disabled;
            ``do_validation_epoch`` calls ``model.eval()`` before using this step.
        input: Batch of network inputs (on the model's device).
        target: Ground-truth targets, one per example in ``input``.
        meta: Batch metadata forwarded to ``joints_mse_loss_onKPloc``.
        data_info: Dataset description; only ``data_info.hflip_indices`` is read,
            and only when ``flip`` is True.
        target_weight: Optional per-joint weights forwarded to the loss.
        flip: If True, also run the horizontally flipped input, map its last output
            back with ``flip_back`` and average it with the un-flipped prediction.

    Returns:
        Tuple ``(heatmaps, loss)``: the last model output (optionally flip-averaged)
        on the CPU, and the loss summed over all model outputs as a Python number.
    """
    # assert not model.training, 'model must be in evaluation mode.'
    assert len(input) == len(target), 'input and target must contain the same number of examples.'

    # Validation never back-propagates. Without no_grad the returned CPU heatmaps keep
    # the full autograd graph of the forward pass alive, and downstream ``.numpy()``
    # calls on derived tensors would raise because they require grad.
    with torch.no_grad():
        # Forward pass and loss calculation.
        output = model(input)
        loss = sum(joints_mse_loss_onKPloc(o, target, meta, target_weight) for o in output)

        # Get the heatmaps.
        if flip:
            # Horizontally flipped inference as well: more robust predictions at the
            # expense of additional compute.
            flip_input = fliplr(input)
            flip_output = model(flip_input)
            flip_output = flip_output[-1].cpu()
            flip_output = flip_back(flip_output.detach(), data_info.hflip_indices)
            heatmaps = (output[-1].cpu() + flip_output) / 2
        else:
            heatmaps = output[-1].cpu()

    return heatmaps, loss.item()
def _ensure_parent_dir(path):
    """Create the parent directory of ``path`` (and its ancestors) if missing."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _save_keypoint_visualization(inp_img, keypoints, meta, ind, example_index, save_imgs_path):
    """Write the keypoint overlay for one example, plus a JSON dump for named samples.

    Args:
        inp_img: Network input image of this example.
        keypoints: Per-joint rows ``(x, y, score)`` already scaled to input pixels.
        meta: Batch metadata; ``meta['name']`` (if present), ``meta['center']`` and
            ``meta['scale']`` are read at position ``ind``.
        ind: Position of the example inside the batch.
        example_index: Dataset index of the example (single-element tensor).
        save_imgs_path: Root output directory.
    """
    out_name = os.path.join(save_imgs_path, 'res_' + str(example_index.item()) + '.png')
    if 'name' in meta.keys():  # we do this for the stanext set
        name = meta['name'][ind]
        out_path_keyp_img = os.path.join(os.path.dirname(out_name), name)
        # The JSON mirrors the image path with '_vis' -> '_json'. splitext replaces
        # whatever the image extension is, so the JSON can never overwrite the image
        # (the former ``.replace('.jpg', '.json')`` only handled '.jpg' names).
        out_path_json = os.path.splitext(out_path_keyp_img.replace('_vis', '_json'))[0] + '.json'
        _ensure_parent_dir(out_path_json)
        _ensure_parent_dir(out_path_keyp_img)
        save_input_image_with_keypoints(inp_img, keypoints, out_path=out_path_keyp_img, ratio_in_out=1.0, threshold=0.1, print_scores=True)  # threshold=0.3
        res_dict = {
            'pred_joints_256': list(keypoints.detach().cpu().numpy().astype(float).reshape((-1))),
            'center': list(meta['center'][ind, :].cpu().numpy().astype(float).reshape((-1))),
            'scale': meta['scale'][ind].item()}
        with open(out_path_json, 'w') as outfile:
            json.dump(res_dict, outfile)
    else:
        save_input_image_with_keypoints(inp_img, keypoints, out_path=out_name, ratio_in_out=1.0, threshold=0.1, print_scores=True)  # threshold=0.3


def do_validation_epoch(val_loader, model, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None):
    """Evaluate ``model`` on ``val_loader`` and collect per-example predictions.

    Args:
        val_loader: Loader yielding ``(input, target, meta)`` batches. ``meta`` must
            contain ``'target_weight'``, ``'center'``, ``'scale'`` and ``'index'``
            (the dataset index of each example), optionally ``'name'``.
            ``len(val_loader.dataset)`` sizes the prediction buffer.
        model: Network to evaluate; switched to eval mode here.
        device: Device the input, target and target weights are moved to.
        data_info: Forwarded to ``do_validation_step``.
        flip: Forwarded to ``do_validation_step`` (flip-averaged inference).
        quiet: If True, no tqdm progress bar is shown.
        acc_joints: Joint indices forwarded to ``accuracy``.
        save_imgs_path: If given, keypoint overlays are written below this directory:
            ``res_<index>.png`` per example, or ``<name>`` plus a JSON file next to
            the '_vis' -> '_json' substituted path when ``meta`` has ``'name'``.

    Returns:
        Tuple ``(mean_loss, mean_accuracy, predictions)``: batch-size weighted
        averages from ``AverageMeter`` (accuracy as a fraction) and the
        ``final_preds`` output of every dataset example stacked in dataset-index
        order. Every index must have been visited, otherwise ``torch.stack`` fails
        on the remaining ``None`` entries.
    """
    losses = AverageMeter()
    accuracies = AverageMeter()
    predictions = [None] * len(val_loader.dataset)

    if save_imgs_path is not None:
        pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True)

    # Put the model in evaluation mode.
    model.eval()

    iterable = val_loader
    progress = None
    if not quiet:
        progress = tqdm(iterable, desc='Valid', total=len(val_loader), ascii=True, leave=False)
        iterable = progress

    # No gradients are needed anywhere in validation; this also guarantees that the
    # tensors converted with ``.numpy()`` for the JSON dumps do not require grad.
    with torch.no_grad():
        for inputs, target, meta in iterable:
            # Copy data to the training device (eg GPU).
            inputs = inputs.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            target_weight = meta['target_weight'].to(device, non_blocking=True)

            heatmaps, loss = do_validation_step(model, inputs, target, meta, data_info, target_weight, flip)

            # Calculate PCK from the predicted heatmaps.
            acc = accuracy(heatmaps, target.cpu(), acc_joints)

            # Calculate locations in original image space from the predicted heatmaps
            # (64x64 heatmap resolution).
            preds = final_preds(heatmaps, meta['center'], meta['scale'], [64, 64])

            # Soft-argmax coordinates and confidences are only needed for the
            # visualisations, so they are computed only when images are saved.
            if save_imgs_path is not None:
                preds_unprocessed, preds_unprocessed_maxval = get_preds_soft(heatmaps, return_maxval=True)

            for ind, (example_index, pose) in enumerate(zip(meta['index'], preds)):
                predictions[example_index] = pose

                if save_imgs_path is not None:
                    keypoints = torch.cat((preds_unprocessed[ind, :, :], preds_unprocessed_maxval[ind, :, :]), 1)
                    # Heatmap -> input-pixel scaling (since bugfix01 in the data
                    # preparation, 08.09.2022, no extra -1 offset is needed). The factor
                    # presumably equals 256 / 64, cf. the 'pred_joints_256' key.
                    keypoints[:, :2] = keypoints[:, :2] * 4
                    _save_keypoint_visualization(inputs[ind, :, :, :], keypoints, meta, ind, example_index, save_imgs_path)

            # Record accuracy and loss for this batch.
            losses.update(loss, inputs.size(0))
            accuracies.update(acc[0].item(), inputs.size(0))

            # Show accuracy and loss as part of the progress bar.
            if progress is not None:
                progress.set_postfix_str('Loss: {loss:0.4f}, Acc: {acc:6.2f}'.format(
                    loss=losses.avg,
                    acc=100 * accuracies.avg
                ))

    predictions = torch.stack(predictions, dim=0)
    return losses.avg, accuracies.avg, predictions