from __future__ import print_function import argparse import inspect import os import pdb import pickle import random import re import shutil import time from collections import * import ipdb import numpy as np # torch import torch import torch.backends.cudnn as cudnn import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import yaml from einops import rearrange, reduce, repeat from evaluation.classificationMAP import getClassificationMAP as cmAP from evaluation.detectionMAP import getSingleStreamDetectionMAP as dsmAP from feeders.tools import collate_with_padding_multi_joint from model.losses import cross_entropy_loss, mvl_loss from sklearn.metrics import f1_score # Custom from tensorboardX import SummaryWriter from torch.autograd import Variable from torch.optim.lr_scheduler import _LRScheduler from tqdm import tqdm from utils.logger import Logger # seed = 0 # random.seed(seed) # np.random.seed(seed) # torch.manual_seed(seed) # torch.cuda.manual_seed_all(seed) # torch.use_deterministic_algorithms(True) # torch.backends.cudnn.deterministic = True # torch.backends.cudnn.benchmark = False def init_seed(seed): torch.cuda.manual_seed_all(seed) torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def get_parser(): # parameter priority: command line > config > default parser = argparse.ArgumentParser( description="Spatial Temporal Graph Convolution Network" ) parser.add_argument( "--work-dir", default="./work_dir/temp", help="the work folder for storing results", ) parser.add_argument("-model_saved_name", default="") parser.add_argument( "--config", default="./config/nturgbd-cross-view/test_bone.yaml", help="path to the configuration file", ) # processor parser.add_argument("--phase", default="train", help="must be train or test") # visulize and debug parser.add_argument("--seed", type=int, default=5, help="random seed for pytorch") parser.add_argument( "--log-interval", type=int, default=100, help="the interval for printing messages (#iteration)", ) parser.add_argument( "--save-interval", type=int, default=2, help="the interval for storing models (#iteration)", ) parser.add_argument( "--eval-interval", type=int, default=5, help="the interval for evaluating models (#iteration)", ) parser.add_argument( "--print-log", type=str2bool, default=True, help="print logging or not" ) parser.add_argument( "--show-topk", type=int, default=[1, 5], nargs="+", help="which Top K accuracy will be shown", ) # feeder parser.add_argument( "--feeder", default="feeder.feeder", help="data loader will be used" ) parser.add_argument( "--num-worker", type=int, default=32, help="the number of worker for data loader", ) parser.add_argument( "--train-feeder-args", default=dict(), help="the arguments of data loader for training", ) parser.add_argument( "--test-feeder-args", default=dict(), help="the arguments of data loader for test", ) # model parser.add_argument("--model", default=None, help="the model will be used") parser.add_argument( "--model-args", type=dict, default=dict(), help="the arguments of model" ) parser.add_argument( "--weights", default=None, help="the weights for network initialization" ) parser.add_argument( "--ignore-weights", type=str, default=[], nargs="+", help="the name of weights which will be ignored in the initialization", ) # optim parser.add_argument( "--base-lr", type=float, default=0.01, help="initial learning rate" ) parser.add_argument( "--step", type=int, default=[60,80], nargs="+", help="the epoch where optimizer reduce the learning rate", ) # training parser.add_argument( "--device", type=int, default=0, nargs="+", help="the indexes of GPUs for training or testing", ) parser.add_argument("--optimizer", default="SGD", help="type of optimizer") parser.add_argument( "--nesterov", type=str2bool, default=False, help="use nesterov or not" ) parser.add_argument( "--batch-size", type=int, default=256, help="training batch size" ) parser.add_argument( "--test-batch-size", type=int, default=256, help="test batch size" ) parser.add_argument( "--start-epoch", type=int, default=0, help="start training from which epoch" ) parser.add_argument( "--num-epoch", type=int, default=80, help="stop training in which epoch" ) parser.add_argument( "--weight-decay", type=float, default=0.0005, help="weight decay for optimizer" ) # loss parser.add_argument("--loss", type=str, default="CE", help="loss type(CE or focal)") parser.add_argument( "--label_count_path", default=None, type=str, help="Path to label counts (used in loss weighting)", ) parser.add_argument( "---beta", type=float, default=0.9999, help="Hyperparameter for Class balanced loss", ) parser.add_argument( "--gamma", type=float, default=2.0, help="Hyperparameter for Focal loss" ) parser.add_argument("--only_train_part", default=False) parser.add_argument("--only_train_epoch", default=0) parser.add_argument("--warm_up_epoch", default=10) parser.add_argument( "--lambda-mil", default=1.0, help="balancing hyper-parameter of mil branch" ) parser.add_argument( "--class-threshold", type=float, default=0.1, help="class threshold for rejection", ) parser.add_argument( "--start-threshold", type=float, default=0.03, help="start threshold for action localization", ) parser.add_argument( "--end-threshold", type=float, default=0.055, help="end threshold for action localization", ) parser.add_argument( "--threshold-interval", type=float, default=0.005, help="threshold interval for action localization", ) return parser class Processor: """ Processor for Skeleton-based Action Recgnition """ def __init__(self, arg): self.arg = arg self.save_arg() if arg.phase == "train": if not arg.train_feeder_args["debug"]: if os.path.isdir(arg.model_saved_name): print("log_dir: ", arg.model_saved_name, "already exist") # answer = input('delete it? y/n:') answer = "y" if answer == "y": print("Deleting dir...") shutil.rmtree(arg.model_saved_name) print("Dir removed: ", arg.model_saved_name) # input('Refresh the website of tensorboard by pressing any keys') else: print("Dir not removed: ", arg.model_saved_name) self.train_writer = SummaryWriter( os.path.join(arg.model_saved_name, "train"), "train" ) self.val_writer = SummaryWriter( os.path.join(arg.model_saved_name, "val"), "val" ) else: self.train_writer = self.val_writer = SummaryWriter( os.path.join(arg.model_saved_name, "test"), "test" ) self.global_step = 0 self.load_model() self.load_optimizer() self.load_data() self.lr = self.arg.base_lr self.best_acc = 0 self.best_per_class_acc = 0 self.loss_nce = torch.nn.BCELoss() self.my_logger = Logger( os.path.join(arg.model_saved_name, "log.txt"), title="SWTAL" ) self.my_logger.set_names(["Step", "cmap"] + [f"map_0.{i}" for i in range(1, 6)]+['avg']) def load_data(self): Feeder = import_class(self.arg.feeder) self.data_loader = dict() if self.arg.phase == "train": self.data_loader["train"] = torch.utils.data.DataLoader( dataset=Feeder(**self.arg.train_feeder_args), batch_size=self.arg.batch_size, shuffle=True, num_workers=self.arg.num_worker, drop_last=True, collate_fn=collate_with_padding_multi_joint, ) self.data_loader["test"] = torch.utils.data.DataLoader( dataset=Feeder(**self.arg.test_feeder_args), batch_size=self.arg.test_batch_size, shuffle=False, num_workers=self.arg.num_worker, drop_last=False, collate_fn=collate_with_padding_multi_joint, ) def load_model(self): output_device = ( self.arg.device[0] if type(self.arg.device) is list else self.arg.device ) self.output_device = output_device Model = import_class(self.arg.model) shutil.copy2(inspect.getfile(Model), self.arg.work_dir) # print(Model) self.model = Model(**self.arg.model_args).cuda(output_device) # print(self.model) self.loss_type = arg.loss if self.arg.weights: # self.global_step = int(arg.weights[:-3].split("-")[-1]) self.print_log("Load weights from {}.".format(self.arg.weights)) if ".pkl" in self.arg.weights: with open(self.arg.weights, "r") as f: weights = pickle.load(f) else: weights = torch.load(self.arg.weights) weights = OrderedDict( [ [k.split("module.")[-1], v.cuda(output_device)] for k, v in weights.items() ] ) keys = list(weights.keys()) for w in self.arg.ignore_weights: for key in keys: if w in key: if weights.pop(key, None) is not None: self.print_log( "Sucessfully Remove Weights: {}.".format(key) ) else: self.print_log("Can Not Remove Weights: {}.".format(key)) try: self.model.load_state_dict(weights) except: state = self.model.state_dict() diff = list(set(state.keys()).difference(set(weights.keys()))) print("Can not find these weights:") for d in diff: print(" " + d) state.update(weights) self.model.load_state_dict(state) if type(self.arg.device) is list: if len(self.arg.device) > 1: self.model = nn.DataParallel( self.model, device_ids=self.arg.device, output_device=output_device ) def load_optimizer(self): if self.arg.optimizer == "SGD": self.optimizer = optim.SGD( self.model.parameters(), lr=self.arg.base_lr, momentum=0.9, nesterov=self.arg.nesterov, weight_decay=self.arg.weight_decay, ) elif self.arg.optimizer == "Adam": self.optimizer = optim.Adam( self.model.parameters(), lr=self.arg.base_lr, weight_decay=self.arg.weight_decay, ) else: raise ValueError() def save_arg(self): # save arg arg_dict = vars(self.arg) if not os.path.exists(self.arg.work_dir): os.makedirs(self.arg.work_dir) with open("{}/config.yaml".format(self.arg.work_dir), "w") as f: yaml.dump(arg_dict, f) def adjust_learning_rate(self, epoch): if self.arg.optimizer == "SGD" or self.arg.optimizer == "Adam": if epoch < self.arg.warm_up_epoch: lr = self.arg.base_lr * (epoch + 1) / self.arg.warm_up_epoch else: lr = self.arg.base_lr * ( 0.1 ** np.sum(epoch >= np.array(self.arg.step)) ) for param_group in self.optimizer.param_groups: param_group["lr"] = lr return lr else: raise ValueError() def print_time(self): localtime = time.asctime(time.localtime(time.time())) self.print_log("Local current time : " + localtime) def print_log(self, str, print_time=True): if print_time: localtime = time.asctime(time.localtime(time.time())) str = "[ " + localtime + " ] " + str print(str) if self.arg.print_log: with open("{}/print_log.txt".format(self.arg.work_dir), "a") as f: print(str, file=f) def record_time(self): self.cur_time = time.time() return self.cur_time def split_time(self): split_time = time.time() - self.cur_time self.record_time() return split_time def train(self, epoch, wb_dict, save_model=False): self.model.train() self.print_log("Training epoch: {}".format(epoch + 1)) loader = self.data_loader["train"] self.adjust_learning_rate(epoch) loss_value, batch_acc = [], [] self.train_writer.add_scalar("epoch", epoch, self.global_step) self.record_time() timer = dict(dataloader=0.001, model=0.001, statistics=0.001) process = tqdm(loader) if self.arg.only_train_part: if epoch > self.arg.only_train_epoch: print("only train part, require grad") for key, value in self.model.named_parameters(): if "PA" in key: value.requires_grad = True else: print("only train part, do not require grad") for key, value in self.model.named_parameters(): if "PA" in key: value.requires_grad = False vid_preds = [] frm_preds = [] vid_lens = [] labels = [] results = [] indexs = [] ''' Switch to FULL supervision Dataloader->Feeder ->collate_with_padding_multi_joint ''' for batch_idx, (data, label, target, mask, index, soft_label) in enumerate( process ): self.global_step += 1 # get data data = data.float().cuda(self.output_device) label = label.cuda(self.output_device) target = target.cuda(self.output_device) mask = mask.cuda(self.output_device) soft_label = soft_label.cuda(self.output_device) timer["dataloader"] += self.split_time() ''' into one hot''' ground_truth_flat = target.view(-1) one_hot_ground_truth = F.one_hot(ground_truth_flat, num_classes=5) ''' into one hot''' indexs.extend(index.cpu().numpy().tolist()) ab_labels = torch.cat([label, torch.ones(label.size(0), 1).cuda()], -1) # forward mil_pred, frm_scrs, mil_pred_2, frm_scrs_2 = self.model(data,mask) cls_mil_loss = self.loss_nce(mil_pred, ab_labels.float()) + self.loss_nce( mil_pred_2, ab_labels.float() ) if epoch > -1: frm_scrs_re = rearrange(frm_scrs, "n t c -> (n t) c") frm_scrs_2_re = rearrange(frm_scrs_2, "n t c -> (n t) c") # soft_label = rearrange(soft_label, "n t c -> (n t) c") loss = cls_mil_loss * 0.1 + mvl_loss( frm_scrs, frm_scrs_2, rate=0.2, weight=0.5 ) loss += cross_entropy_loss( frm_scrs_re, one_hot_ground_truth ) + cross_entropy_loss(frm_scrs_2_re, one_hot_ground_truth) # else: # loss = cls_mil_loss * self.arg.lambda_mil + mvl_loss( # frm_scrs, frm_scrs_2, rate=0.2, weight=0.5 # ) for i in range(data.size(0)): frm_scr = frm_scrs[i] label_ = label[i].cpu().numpy() mask_ = mask[i].cpu().numpy() vid_len = mask_.sum() frm_pred = F.softmax(frm_scr, -1).detach().cpu().numpy()[:vid_len] vid_pred = mil_pred[i].detach().cpu().numpy() results.append(frm_pred) vid_preds.append(vid_pred) frm_preds.append(frm_pred) vid_lens.append(vid_len) labels.append(label_) # backward self.optimizer.zero_grad() loss.backward() self.optimizer.step() loss_value.append(loss.data.item()) timer["model"] += self.split_time() vid_preds = np.array(vid_preds) frm_preds = np.array(frm_preds) vid_lens = np.array(vid_lens) labels = np.array(labels) loader.dataset.label_update(results, indexs) cmap = cmAP(vid_preds, labels) self.train_writer.add_scalar("acc", cmap, self.global_step) self.train_writer.add_scalar("loss", np.mean(loss_value), self.global_step) # statistics self.lr = self.optimizer.param_groups[0]["lr"] self.train_writer.add_scalar("lr", self.lr, self.global_step) timer["statistics"] += self.split_time() # statistics of time consumption and loss self.print_log("\tMean training loss: {:.4f}.".format(np.mean(loss_value))) self.print_log("\tAcc score: {:.3f}%".format(cmap)) # Log wb_dict["train loss"] = np.mean(loss_value) wb_dict["train acc"] = cmap if save_model: state_dict = self.model.state_dict() weights = OrderedDict( [[k.split("module.")[-1], v.cpu()] for k, v in state_dict.items()] ) torch.save( weights, self.arg.model_saved_name + str(epoch) + ".pt", ) return wb_dict @torch.no_grad() def eval( self, epoch, wb_dict, loader_name=["test"], ): self.model.eval() self.print_log("Eval epoch: {}".format(epoch + 1)) vid_preds = [] frm_preds = [] vid_lens = [] labels = [] for ln in loader_name: loss_value = [] step = 0 process = tqdm(self.data_loader[ln]) for batch_idx, (data, label, target, mask, index, soft_label) in enumerate( process ): data = data.float().cuda(self.output_device) label = label.cuda(self.output_device) mask = mask.cuda(self.output_device) ab_labels = torch.cat([label, torch.ones(label.size(0), 1).cuda()], -1) # forward mil_pred, frm_scrs, mil_pred_2, frm_scrs_2 = self.model(data,mask) '''Loc LOSS''' target = target.cuda(self.output_device) ''' into one hot''' ground_truth_flat = target.view(-1) one_hot_ground_truth = F.one_hot(ground_truth_flat, num_classes=5) ''' into one hot''' frm_scrs_re = rearrange(frm_scrs, "n t c -> (n t) c") frm_scrs_2_re = rearrange(frm_scrs_2, "n t c -> (n t) c") '''Loc LOSS''' cls_mil_loss = self.loss_nce( mil_pred, ab_labels.float() ) + self.loss_nce(mil_pred_2, ab_labels.float()) loss_co = mvl_loss(frm_scrs, frm_scrs_2, rate=0.2, weight=0.5) loss = cls_mil_loss * self.arg.lambda_mil + loss_co '''Loc LOSS''' loss += cross_entropy_loss( frm_scrs_re, one_hot_ground_truth ) + cross_entropy_loss(frm_scrs_2_re, one_hot_ground_truth) '''Loc LOSS''' loss_value.append(loss.data.item()) for i in range(data.size(0)): frm_scr = frm_scrs[i] vid_pred = mil_pred[i] label_ = label[i].cpu().numpy() mask_ = mask[i].cpu().numpy() vid_len = mask_.sum() frm_pred = F.softmax(frm_scr, -1).cpu().numpy()[:vid_len] vid_pred = vid_pred.cpu().numpy() vid_preds.append(vid_pred) frm_preds.append(frm_pred) vid_lens.append(vid_len) labels.append(label_) step += 1 vid_preds = np.array(vid_preds) frm_preds = np.array(frm_preds) vid_lens = np.array(vid_lens) labels = np.array(labels) cmap = cmAP(vid_preds, labels) score = cmap loss = np.mean(loss_value) dmap, iou = dsmAP( vid_preds, frm_preds, vid_lens, self.arg.test_feeder_args["data_path"], self.arg, multi=True, ) print("Classification map %f" % cmap) for item in list(zip(iou, dmap)): print("Detection map @ %f = %f" % (item[0], item[1])) self.my_logger.append([epoch + 1, cmap] + dmap+ [np.mean(dmap)]) wb_dict["val loss"] = loss wb_dict["val acc"] = score if score > self.best_acc: self.best_acc = score print("Acc score: ", score, " model: ", self.arg.model_saved_name) if self.arg.phase == "train": self.val_writer.add_scalar("loss", loss, self.global_step) self.val_writer.add_scalar("acc", score, self.global_step) self.print_log( "\tMean {} loss of {} batches: {}.".format( ln, len(self.data_loader[ln]), np.mean(loss_value) ) ) self.print_log("\tAcc score: {:.3f}%".format(score)) return wb_dict def start(self): wb_dict = {} if self.arg.phase == "train": self.print_log("Parameters:\n{}\n".format(str(vars(self.arg)))) self.global_step = ( self.arg.start_epoch * len(self.data_loader["train"]) / self.arg.batch_size ) for epoch in range(self.arg.start_epoch, self.arg.num_epoch): save_model = ((epoch + 1) % self.arg.save_interval == 0) or ( epoch + 1 == self.arg.num_epoch ) wb_dict = {"lr": self.lr} # Train wb_dict = self.train(epoch, wb_dict, save_model=save_model) if epoch%10==0: # Eval. on val set wb_dict = self.eval(epoch, wb_dict, loader_name=["test"]) # Log stats. for this epoch print("Epoch: {0}\nMetrics: {1}".format(epoch, wb_dict)) print( "best accuracy: ", self.best_acc, " model_name: ", self.arg.model_saved_name, ) elif self.arg.phase == "test": if not self.arg.test_feeder_args["debug"]: wf = self.arg.model_saved_name + "_wrong.txt" rf = self.arg.model_saved_name + "_right.txt" else: wf = rf = None if self.arg.weights is None: raise ValueError("Please appoint --weights.") self.arg.print_log = False self.print_log("Model: {}.".format(self.arg.model)) self.print_log("Weights: {}.".format(self.arg.weights)) wb_dict = self.eval( epoch=0, wb_dict=wb_dict, loader_name=["test"], wrong_file=wf, result_file=rf, ) print("Inference metrics: ", wb_dict) self.print_log("Done.\n") def str2bool(v): if v.lower() in ("yes", "true", "t", "y", "1"): return True elif v.lower() in ("no", "false", "f", "n", "0"): return False else: raise argparse.ArgumentTypeError("Boolean value expected.") def import_class(name): components = name.split(".") mod = __import__(components[0]) for comp in components[1:]: mod = getattr(mod, comp) return mod if __name__ == "__main__": parser = get_parser() # load arg form config file p = parser.parse_args() if p.config is not None: with open(p.config, "r") as f: default_arg = yaml.safe_load(f) key = vars(p).keys() for k in default_arg.keys(): if k not in key: print("WRONG ARG: {}".format(k)) assert k in key parser.set_defaults(**default_arg) arg = parser.parse_args() print("BABEL Action Recognition") print("Config: ", arg) init_seed(arg.seed) processor = Processor(arg) processor.start()