Spaces:
Paused
Paused
import os | |
import time | |
import torch | |
import torch.nn as nn | |
from torch import optim | |
from torch.utils.data import DataLoader | |
from torch.utils.data.dataloader import default_collate | |
import numpy as np | |
from datetime import datetime | |
import torch.nn.functional as F | |
from datasets.crowd import Crowd_TC, Crowd_UL_TC | |
from network import pvt_cls as TCN | |
from losses.multi_con_loss import MultiConLoss | |
from utils.pytorch_utils import Save_Handle, AverageMeter | |
import utils.log_utils as log_utils | |
import argparse | |
from losses.rank_loss import RankLoss | |
from losses import ramps | |
from losses.ot_loss import OT_Loss | |
from losses.consistency_loss import * | |
parser = argparse.ArgumentParser(description='Train') | |
parser.add_argument('--data-dir', default='/users/k2254235/Lab/TCT/Dataset/London_103050/', help='data path') | |
parser.add_argument('--dataset', default='TC') | |
parser.add_argument('--lr', type=float, default=1e-5, help='the initial learning rate') | |
parser.add_argument('--weight-decay', type=float, default=1e-4, help='the weight decay') | |
parser.add_argument('--resume', default='', type=str, help='the path of resume training model') | |
parser.add_argument('--max-epoch', type=int, default=4000, help='max training epoch') | |
parser.add_argument('--val-epoch', type=int, default=1, help='the num of steps to log training information') | |
parser.add_argument('--val-start', type=int, default=0, help='the epoch start to val') | |
parser.add_argument('--batch-size', type=int, default=16, help='train batch size') | |
parser.add_argument('--batch-size-ul', type=int, default=16, help='train batch size') | |
parser.add_argument('--device', default='0', help='assign device') | |
parser.add_argument('--num-workers', type=int, default=0, help='the num of training process') | |
parser.add_argument('--crop-size', type=int, default= 256, help='the crop size of the train image') | |
parser.add_argument('--rl', type=float, default=1, help='entropy regularization in sinkhorn') | |
parser.add_argument('--reg', type=float, default=1, help='entropy regularization in sinkhorn') | |
parser.add_argument('--ot', type=float, default=0.1, help='entropy regularization in sinkhorn') | |
parser.add_argument('--tv', type=float, default=0.01, help='entropy regularization in sinkhorn') | |
parser.add_argument('--num-of-iter-in-ot', type=int, default=100, help='sinkhorn iterations') | |
parser.add_argument('--norm-cood', type=int, default=0, help='whether to norm cood when computing distance') | |
parser.add_argument('--run-name', default='Treeformer_test', help='run name for wandb interface/logging') | |
parser.add_argument('--consistency', type=int, default=1, help='whether to norm cood when computing distance') | |
args = parser.parse_args() | |
def train_collate(batch): | |
transposed_batch = list(zip(*batch)) | |
images = torch.stack(transposed_batch[0], 0) | |
gauss = torch.stack(transposed_batch[1], 0) | |
points = transposed_batch[2] | |
gt_discretes = torch.stack(transposed_batch[3], 0) | |
return images, gauss, points, gt_discretes | |
def train_collate_UL(batch): | |
transposed_batch = list(zip(*batch)) | |
images = torch.stack(transposed_batch[0], 0) | |
return images | |
def get_current_consistency_weight(epoch): | |
# Consistency ramp-up from https://arxiv.org/abs/1610.02242 | |
return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_ramp) | |
class Trainer(object): | |
def __init__(self, args): | |
self.args = args | |
def setup(self): | |
args = self.args | |
sub_dir = ( | |
"SEMI/{}_12-1-input-{}_reg-{}_nIter-{}_normCood-{}".format( | |
args.run_name,args.crop_size,args.reg, | |
args.num_of_iter_in_ot,args.norm_cood)) | |
self.save_dir = os.path.join("/scratch/users/k2254235","ckpts", sub_dir) | |
if not os.path.exists(self.save_dir): | |
os.makedirs(self.save_dir) | |
time_str = datetime.strftime(datetime.now(), "%m%d-%H%M%S") | |
self.logger = log_utils.get_logger( | |
os.path.join(self.save_dir, "train-{:s}.log".format(time_str))) | |
log_utils.print_config(vars(args), self.logger) | |
if torch.cuda.is_available(): | |
self.device = torch.device("cuda") | |
self.device_count = torch.cuda.device_count() | |
self.logger.info("using {} gpus".format(self.device_count)) | |
else: | |
raise Exception("gpu is not available") | |
downsample_ratio = 4 | |
self.datasets = {"train": Crowd_TC(os.path.join(args.data_dir, "train_data"), args.crop_size, | |
downsample_ratio, "train"), "val": Crowd_TC(os.path.join(args.data_dir, "valid_data"), | |
args.crop_size, downsample_ratio, "val")} | |
self.datasets_ul = { "train_ul": Crowd_UL_TC(os.path.join(args.data_dir, "train_data_ul"), | |
args.crop_size, downsample_ratio, "train_ul")} | |
self.dataloaders = { | |
x: DataLoader(self.datasets[x], | |
collate_fn=(train_collate if x == "train" else default_collate), | |
batch_size=(args.batch_size if x == "train" else 1), | |
shuffle=(True if x == "train" else False), | |
num_workers=args.num_workers * self.device_count, | |
pin_memory=(True if x == "train" else False)) | |
for x in ["train", "val"]} | |
self.dataloaders_ul = { | |
x: DataLoader(self.datasets_ul[x], | |
collate_fn=(train_collate_UL ), | |
batch_size=(args.batch_size_ul), | |
shuffle=(True), | |
num_workers=args.num_workers * self.device_count, | |
pin_memory=(True if x == "train" else False)) | |
for x in ["train_ul"]} | |
self.model = TCN.pvt_treeformer(pretrained=False) | |
self.model.to(self.device) | |
self.optimizer = optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay) | |
self.start_epoch = 0 | |
if args.resume: | |
self.logger.info("loading pretrained model from " + args.resume) | |
suf = args.resume.rsplit(".", 1)[-1] | |
if suf == "tar": | |
checkpoint = torch.load(args.resume, self.device) | |
self.model.load_state_dict(checkpoint["model_state_dict"]) | |
self.optimizer.load_state_dict( | |
checkpoint["optimizer_state_dict"]) | |
self.start_epoch = checkpoint["epoch"] + 1 | |
elif suf == "pth": | |
self.model.load_state_dict( | |
torch.load(args.resume, self.device)) | |
else: | |
self.logger.info("random initialization") | |
self.ot_loss = OT_Loss(args.crop_size, downsample_ratio, args.norm_cood, | |
self.device, args.num_of_iter_in_ot, args.reg) | |
self.tvloss = nn.L1Loss(reduction="none").to(self.device) | |
self.mse = nn.MSELoss().to(self.device) | |
self.mae = nn.L1Loss().to(self.device) | |
self.save_list = Save_Handle(max_num=1) | |
self.best_mae = np.inf | |
self.best_mse = np.inf | |
self.rankloss = RankLoss().to(self.device) | |
self.kl_distance = nn.KLDivLoss(reduction='none') | |
self.multiconloss = MultiConLoss().to(self.device) | |
def train(self): | |
"""training process""" | |
args = self.args | |
for epoch in range(self.start_epoch, args.max_epoch + 1): | |
self.logger.info("-" * 5 + "Epoch {}/{}".format(epoch, args.max_epoch) + "-" * 5) | |
self.epoch = epoch | |
self.train_epoch() | |
if epoch % args.val_epoch == 0 and epoch >= args.val_start: | |
self.val_epoch() | |
def train_epoch(self): | |
epoch_ot_loss = AverageMeter() | |
epoch_ot_obj_value = AverageMeter() | |
epoch_wd = AverageMeter() | |
epoch_tv_loss = AverageMeter() | |
epoch_count_loss = AverageMeter() | |
epoch_count_consistency_l = AverageMeter() | |
epoch_count_consistency_ul = AverageMeter() | |
epoch_loss = AverageMeter() | |
epoch_mae = AverageMeter() | |
epoch_mse = AverageMeter() | |
epoch_start = time.time() | |
epoch_rank_loss = AverageMeter() | |
epoch_consistensy_loss = AverageMeter() | |
self.model.train() # Set model to training mode | |
for step, (inputs, gausss, points, gt_discrete) in enumerate(self.dataloaders["train"]): | |
inputs = inputs.to(self.device) | |
gausss = gausss.to(self.device) | |
gd_count = np.array([len(p) for p in points], dtype=np.float32) | |
points = [p.to(self.device) for p in points] | |
gt_discrete = gt_discrete.to(self.device) | |
N = inputs.size(0) | |
for st, unlabel_data in enumerate(self.dataloaders_ul["train_ul"]): | |
inputs_ul = unlabel_data.to(self.device) | |
break | |
with torch.set_grad_enabled(True): | |
outputs_L, outputs_UL, outputs_normed, CLS_L, CLS_UL = self.model(inputs, inputs_ul) | |
outputs_L = outputs_L[0] | |
with torch.set_grad_enabled(False): | |
preds_UL = (outputs_UL[0][0] + outputs_UL[1][0] + outputs_UL[2][0])/3 | |
# Compute counting loss. | |
count_loss = self.mae(outputs_L.sum(1).sum(1).sum(1),torch.from_numpy(gd_count).float().to(self.device))*self.args.reg | |
# Compute OT loss. | |
ot_loss, wd, ot_obj_value = self.ot_loss(outputs_normed, outputs_L, points) | |
ot_loss = ot_loss* self.args.ot | |
ot_obj_value = ot_obj_value* self.args.ot | |
gd_count_tensor = (torch.from_numpy(gd_count).float().to(self.device).unsqueeze(1).unsqueeze(2).unsqueeze(3)) | |
gt_discrete_normed = gt_discrete / (gd_count_tensor + 1e-6) | |
tv_loss = (self.tvloss(outputs_normed, gt_discrete_normed).sum(1).sum(1).sum(1)* | |
torch.from_numpy(gd_count).float().to(self.device)).mean(0) * self.args.tv | |
epoch_ot_loss.update(ot_loss.item(), N) | |
epoch_ot_obj_value.update(ot_obj_value.item(), N) | |
epoch_wd.update(wd, N) | |
epoch_count_loss.update(count_loss.item(), N) | |
epoch_tv_loss.update(tv_loss.item(), N) | |
# Compute ranking loss. | |
rank_loss = self.rankloss(outputs_UL)*self.args.rl | |
epoch_rank_loss.update(rank_loss.item(), N) | |
# Compute multi level consistancy loss | |
consistency_loss = args.consistency * self.multiconloss(outputs_UL) | |
epoch_consistensy_loss.update(consistency_loss.item(), N) | |
# Compute consistency count | |
Con_cls_UL = (CLS_UL[0] + CLS_UL[1] + CLS_UL[2])/3 | |
Con_cls_L = torch.from_numpy(gd_count).float().to(self.device) | |
count_loss_l = self.mae(torch.stack((CLS_L[0],CLS_L[1],CLS_L[2])), torch.stack((Con_cls_L, Con_cls_L, Con_cls_L))) | |
count_loss_ul = self.mae(torch.stack((CLS_UL[0],CLS_UL[1],CLS_UL[2])), torch.stack((Con_cls_UL, Con_cls_UL, Con_cls_UL))) | |
epoch_count_consistency_l.update(count_loss_l.item(), N) | |
epoch_count_consistency_ul.update(count_loss_ul.item(), N) | |
loss = count_loss + ot_loss + tv_loss + rank_loss + count_loss_l + count_loss_ul + consistency_loss | |
self.optimizer.zero_grad() | |
loss.backward() | |
self.optimizer.step() | |
pred_count = (torch.sum(outputs_L.view(N, -1), | |
dim=1).detach().cpu().numpy()) | |
pred_err = pred_count - gd_count | |
epoch_loss.update(loss.item(), N) | |
epoch_mse.update(np.mean(pred_err * pred_err), N) | |
epoch_mae.update(np.mean(abs(pred_err)), N) | |
self.logger.info( | |
"Epoch {} Train, Loss: {:.2f}, Count Loss: {:.2f}, OT Loss: {:.2e}, TV Loss: {:.2e}, Rank Loss: {:.2f}," | |
"Consistensy Loss: {:.2f}, MSE: {:.2f}, MAE: {:.2f},LC Loss: {:.2f}, ULC Loss: {:.2f}, Cost {:.1f} sec".format( | |
self.epoch, epoch_loss.get_avg(), epoch_count_loss.get_avg(), epoch_ot_loss.get_avg(), epoch_tv_loss.get_avg(), epoch_rank_loss.get_avg(), | |
epoch_consistensy_loss.get_avg(), np.sqrt(epoch_mse.get_avg()), epoch_mae.get_avg(), epoch_count_consistency_l.get_avg(), | |
epoch_count_consistency_ul.get_avg(), time.time() - epoch_start)) | |
model_state_dic = self.model.state_dict() | |
save_path = os.path.join(self.save_dir, "{}_ckpt.tar".format(self.epoch)) | |
torch.save({"epoch": self.epoch, "optimizer_state_dict": self.optimizer.state_dict(), | |
"model_state_dict": model_state_dic}, save_path) | |
self.save_list.append(save_path) | |
def val_epoch(self): | |
args = self.args | |
epoch_start = time.time() | |
self.model.eval() # Set model to evaluate mode | |
epoch_res = [] | |
for inputs, count, name, gauss_im in self.dataloaders["val"]: | |
with torch.no_grad(): | |
inputs = inputs.to(self.device) | |
crop_imgs, crop_masks = [], [] | |
b, c, h, w = inputs.size() | |
rh, rw = args.crop_size, args.crop_size | |
for i in range(0, h, rh): | |
gis, gie = max(min(h - rh, i), 0), min(h, i + rh) | |
for j in range(0, w, rw): | |
gjs, gje = max(min(w - rw, j), 0), min(w, j + rw) | |
crop_imgs.append(inputs[:, :, gis:gie, gjs:gje]) | |
mask = torch.zeros([b, 1, h, w]).to(self.device) | |
mask[:, :, gis:gie, gjs:gje].fill_(1.0) | |
crop_masks.append(mask) | |
crop_imgs, crop_masks = map( | |
lambda x: torch.cat(x, dim=0), (crop_imgs, crop_masks)) | |
crop_preds = [] | |
nz, bz = crop_imgs.size(0), args.batch_size | |
for i in range(0, nz, bz): | |
gs, gt = i, min(nz, i + bz) | |
crop_pred, _ = self.model(crop_imgs[gs:gt]) | |
crop_pred = crop_pred[0] | |
_, _, h1, w1 = crop_pred.size() | |
crop_pred = (F.interpolate(crop_pred, size=(h1 * 4, w1 * 4), | |
mode="bilinear", align_corners=True) / 16 ) | |
crop_preds.append(crop_pred) | |
crop_preds = torch.cat(crop_preds, dim=0) | |
# splice them to the original size | |
idx = 0 | |
pred_map = torch.zeros([b, 1, h, w]).to(self.device) | |
for i in range(0, h, rh): | |
gis, gie = max(min(h - rh, i), 0), min(h, i + rh) | |
for j in range(0, w, rw): | |
gjs, gje = max(min(w - rw, j), 0), min(w, j + rw) | |
pred_map[:, :, gis:gie, gjs:gje] += crop_preds[idx] | |
idx += 1 | |
# for the overlapping area, compute average value | |
mask = crop_masks.sum(dim=0).unsqueeze(0) | |
outputs = pred_map / mask | |
res = count[0].item() - torch.sum(outputs).item() | |
epoch_res.append(res) | |
epoch_res = np.array(epoch_res) | |
mse = np.sqrt(np.mean(np.square(epoch_res))) | |
mae = np.mean(np.abs(epoch_res)) | |
self.logger.info("Epoch {} Val, MSE: {:.2f}, MAE: {:.2f}, Cost {:.1f} sec".format( | |
self.epoch, mse, mae, time.time() - epoch_start )) | |
model_state_dic = self.model.state_dict() | |
print("Comaprison", mae, self.best_mae) | |
if mae < self.best_mae: | |
self.best_mse = mse | |
self.best_mae = mae | |
self.logger.info( | |
"save best mse {:.2f} mae {:.2f} model epoch {}".format( | |
self.best_mse, self.best_mae, self.epoch)) | |
print("Saving best model at {} epoch".format(self.epoch)) | |
model_path = os.path.join( | |
self.save_dir, "best_model_mae-{:.2f}_epoch-{}.pth".format( | |
self.best_mae, self.epoch)) | |
torch.save(model_state_dic, model_path) | |
if __name__ == "__main__": | |
import torch | |
torch.backends.cudnn.benchmark = True | |
trainer = Trainer(args) | |
trainer.setup() | |
trainer.train() | |