# workshop / LaSA / train.py
# Web-page residue kept for provenance: uploader "qiushuocheng", commit "update" (d53e869).
import argparse
import os
import random
import time
import sys
import clip
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from libs import models
from libs.checkpoint import resume, save_checkpoint
from libs.class_id_map import get_n_classes
from libs.class_weight import get_class_weight, get_pos_weight, get_pos_weight_BABEL3
from libs.config import get_config
from libs.dataset import ActionSegmentationDataset, collate_fn
from libs.helper import train, validate
from libs.loss_fn import ActionSegmentationLoss, BoundaryRegressionLoss, KLLoss
from libs.optimizer import get_optimizer
from libs.transformer import TempDownSamp, ToTensor
from prompt.text_prompt import TextCLIP, text_prompt_for_class, text_prompt_for_joint
import numpy as np
def get_arguments() -> argparse.Namespace:
    """Build the training command-line parser and parse ``sys.argv``.

    Returns:
        The parsed namespace with the attributes ``dataset`` (str),
        ``result_path`` (str), ``cuda`` (int) and ``resume`` (bool).
    """
    cli = argparse.ArgumentParser(description="train a network for action segmentation")

    # (flag, value type, default, help text) for every option that takes a value.
    valued_options = (
        ("--dataset", str, "PKU-view", "name of the dataset"),
        ("--result_path", str, "./result", "path of a result"),
        ("--cuda", int, 0, "cuda id"),
    )
    for flag, value_type, fallback, description in valued_options:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)

    # Boolean switch: present -> True, absent -> False.
    cli.add_argument(
        "--resume",
        action="store_true",
        help="Add --resume option if you start training from checkpoint.",
    )

    namespace = cli.parse_args()
    return namespace
def import_class(import_str):
    """Import and return the object named by a dotted path.

    Args:
        import_str: Dotted path whose last component is the attribute to
            fetch, e.g. ``"package.module.ClassName"``.

    Returns:
        The attribute ``ClassName`` of the imported module.

    Raises:
        ImportError: If the module imports fine but has no such attribute.
            Errors raised while importing the module itself propagate as-is.
    """
    # Imported locally: the except branch below needs it and the file's
    # import block does not provide it.
    import traceback

    mod_str, _sep, class_str = import_str.rpartition('.')
    __import__(mod_str)
    try:
        # __import__ returns the top-level package, so look the submodule up
        # in sys.modules instead.
        return getattr(sys.modules[mod_str], class_str)
    except AttributeError as exc:
        raise ImportError(
            'Class %s cannot be found (%s)'
            % (class_str, traceback.format_exception(*sys.exc_info()))
        ) from exc
def change_label_score(best_test, train_loss, epoch, cls_acc, edit_score, f1s, map=(0,)):
    """Overwrite a best-score record in place with the scores of one epoch.

    Args:
        best_test: Dict holding the tracked scores; mutated in place.
        train_loss: Training loss of the epoch.
        epoch: Epoch index the scores belong to.
        cls_acc: Classification accuracy.
        edit_score: Edit score.
        f1s: Sequence with at least five F1 values, stored under the keys
            ``f1s@0.1``, ``f1s@0.25``, ``f1s@0.5``, ``f1s@0.75``, ``f1s@0.9``
            in that order.
        map: mAP value(s); their mean is stored under ``'mAP'``.  The name
            shadows the builtin but is kept so existing keyword callers
            still work.

    Returns:
        None.
    """
    best_test['train_loss'] = train_loss
    best_test['epoch'] = epoch
    best_test['cls_acc'] = cls_acc
    best_test['edit'] = edit_score
    best_test['f1s@0.1'] = f1s[0]
    best_test['f1s@0.25'] = f1s[1]
    best_test['f1s@0.5'] = f1s[2]
    best_test['f1s@0.75'] = f1s[3]
    best_test['f1s@0.9'] = f1s[4]
    # The tracker dicts are created with, and compared against, the key 'mAP'.
    # Writing to 'map' left 'mAP' at its initial value forever.
    best_test['mAP'] = np.mean(map)
def main() -> None:
    """Train the LaSA action-segmentation model and record the best scores.

    Reads the dataset name, result directory, CUDA id and resume flag from
    the command line, loads ``config/<dataset>/config.yaml``, builds the data
    loaders, the ``libs.models.LaSA.Model`` network, the CLIP text encoder,
    the optimizer and the losses, then runs the training loop.  When
    ``config.param_search`` is set, each epoch is followed by a validation
    pass, and the weights with the best accuracy, F1@0.1, F1@0.5 and mAP are
    saved under the result directory.  A checkpoint and ``log.csv`` are
    written after every epoch; the checkpoint is removed once training ends.

    NOTE(review): block nesting in this function was reconstructed from a
    copy of the file that had lost its indentation -- confirm against the
    upstream source.
    """
    start_start = time.time()

    # argparser
    args = get_arguments()
    dataset_name = args.dataset
    device_num = args.cuda

    # configuration
    config = get_config(f"config/{dataset_name}/config.yaml")
    result_path = os.path.join(args.result_path, config.dataset, 'split' + str(config.split))
    print('\n---------------------------result_path---------------------------\n')
    print('result_path:', result_path)  # e.g. './result/LARA/DeST_tcn/split1'
    os.makedirs(result_path, exist_ok=True)

    checkpoint_path = os.path.join(result_path, "checkpoint.pth")
    log_path = os.path.join(result_path, "log.csv")
    scores_path = f'{result_path}/scores.txt'

    # Start a fresh score file, except when we really continue from a
    # checkpoint: then the scores of the earlier epochs must be kept.
    resuming = args.resume and os.path.exists(checkpoint_path)
    if not (resuming and os.path.exists(scores_path)):
        with open(scores_path, "w") as file:
            file.write('The result printed:\n')

    seed = config.seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

    # cpu or cuda
    # NOTE(review): cudnn.benchmark is enabled right after cudnn.deterministic;
    # confirm this combination is intended.
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
    device = device_num  # e.g. 0
    output_device = device_num[0] if isinstance(device_num, list) else device_num
    torch.cuda.set_device(output_device)
    # set the CUDA_VISIBLE_DEVICES environment variable
    # NOTE(review): this happens after torch.cuda.set_device(); confirm it has
    # the intended effect at this point.
    if isinstance(device, list):
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_num))
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = f'{device_num}'
    current_device = torch.cuda.current_device()
    print(f"Currently using GPU {current_device}")

    # Dataloader
    # Temporal downsampling is applied to only videos in LARA
    downsamp_rate = 4 if config.dataset == "LARA" else 1

    train_data = ActionSegmentationDataset(
        config.dataset,
        transform=Compose([ToTensor(), TempDownSamp(downsamp_rate)]),
        mode="trainval" if not config.param_search else "training",
        split=config.split,
        dataset_dir=config.dataset_dir,
        csv_dir=config.csv_dir,
    )
    train_loader = DataLoader(
        train_data,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,  # e.g. 4
        drop_last=config.batch_size > 1,
        collate_fn=collate_fn,
    )

    # if you do validation to determine hyperparams
    if config.param_search:
        val_data = ActionSegmentationDataset(
            config.dataset,
            transform=Compose([ToTensor(), TempDownSamp(downsamp_rate)]),
            mode="validation",
            split=config.split,
            dataset_dir=config.dataset_dir,
            csv_dir=config.csv_dir,
        )
        val_loader = DataLoader(
            val_data,
            batch_size=1,
            shuffle=False,
            num_workers=config.num_workers,
            collate_fn=collate_fn,
        )

    # load model
    print("---------- Loading Model ----------")
    n_classes = get_n_classes(config.dataset, dataset_dir=config.dataset_dir)  # number of classes
    # Text prompts for classes and joints, passed to train()/validate() below.
    class_text_list = text_prompt_for_class(dataset_name, "detail")
    joint_text_list = text_prompt_for_joint(dataset_name, "detail")

    Model = import_class("libs.models.LaSA.Model")
    model = Model(  # instantiate directly from the config values
        in_channel=config.in_channel,  # e.g. 12
        n_features=config.n_features,  # e.g. 64
        n_classes=n_classes,  # e.g. 8
        n_stages=config.n_stages,  # e.g. 4
        n_layers=config.n_layers,  # e.g. 10
        n_refine_layers=config.n_refine_layers,  # e.g. 10
        n_stages_asb=config.n_stages_asb,  # e.g. 2
        n_stages_brb=config.n_stages_brb,  # e.g. 3
        SFI_layer=config.SFI_layer,  # e.g. {1,2,3,4,5,6,7,8,9}
        dataset=config.dataset,  # e.g. LARA
    )

    # The preprocessing transform returned by clip.load is not used here.
    model_, _ = clip.load("ViT-B/32", "cuda" if torch.cuda.is_available() else "cpu")
    model_text = TextCLIP(model_)
    model_text = model_text.cuda(output_device)

    # send the model to cuda/cpu
    model.to(output_device)

    if isinstance(device, list):
        if len(device) > 1:  # Placing the model in multiple GPUS, currently unavailable
            model = nn.DataParallel(model, device_ids=device, output_device=output_device)
            model_text = nn.DataParallel(model_text, device_ids=device, output_device=output_device)

    # Adam or SGD; only `model` is handed to the optimizer, the text model is not.
    optimizer = get_optimizer(
        config.optimizer,
        model,
        config.learning_rate,
        momentum=config.momentum,
        dampening=config.dampening,
        weight_decay=config.weight_decay,
        nesterov=config.nesterov,
    )

    # Columns of log.csv.  With param_search they are extended by val_loss,
    # cls_acc, edit, one f1s@<threshold> per IoU threshold, and the four
    # boundary metrics.
    columns = ["epoch", "lr", "train_loss"]
    if config.param_search:
        columns += ["val_loss", "cls_acc", "edit"]
        columns += ["f1s@{}".format(threshold) for threshold in config.iou_thresholds]
        columns += ["bound_acc", "precision", "recall", "bound_f1s"]

    begin_epoch = 0
    best_loss = float("inf")

    # Define temporary variables for evaluation scores
    best_test_acc = {'epoch':0,'train_loss':0,'cls_acc':0,'edit':0,'f1s@0.1':0,'f1s@0.25':0,'f1s@0.5':0,'f1s@0.75':0,'f1s@0.9':0,'mAP':0}
    best_test_F1_10 = best_test_acc.copy()
    best_test_F1_50 = best_test_acc.copy()
    best_test_MAP = best_test_acc.copy()

    log = pd.DataFrame(columns=columns)

    # resume if you want
    if args.resume:
        if os.path.exists(checkpoint_path):
            checkpoint = resume(result_path, model, optimizer)
            begin_epoch, model, optimizer, best_loss = checkpoint
            # usecols drops the unnamed index column that earlier versions of
            # this script wrote into log.csv, so the row width matches `columns`.
            log = pd.read_csv(log_path, usecols=columns)
            print("training will start from {} epoch".format(begin_epoch))
        else:
            print("there is no checkpoint at the result folder")

    # obtain category weights, with parameters including dataset, data
    # partitioning, dataset path, CSV path, and pattern
    if config.class_weight:
        class_weight = get_class_weight(
            config.dataset,
            split=config.split,
            dataset_dir=config.dataset_dir,
            csv_dir=config.csv_dir,
            mode="training" if config.param_search else "trainval",
        )
        class_weight = class_weight.to(output_device)
    else:
        class_weight = None

    # Including cross entropy loss and Gaussian smoothing loss
    # NOTE(review): gstmse_weight receives config.gstmse, while the sibling
    # *_weight arguments receive config.*_weight -- confirm this is intended.
    criterion_cls = ActionSegmentationLoss(
        ce=config.ce,
        focal=config.focal,
        tmse=config.tmse,
        gstmse=config.gstmse,
        weight=class_weight,
        ignore_index=255,
        ce_weight=config.ce_weight,
        focal_weight=config.focal_weight,
        tmse_weight=config.tmse_weight,
        gstmse_weight=config.gstmse,
    ).cuda(output_device)

    pos_weight = get_pos_weight(
        dataset=config.dataset,
        split=config.split,
        csv_dir=config.csv_dir,
        mode="training" if config.param_search else "trainval",
    ).to(output_device)

    criterion_bound = BoundaryRegressionLoss(pos_weight=pos_weight).cuda(output_device)  # a binary cross entropy loss
    criterion_contrast = KLLoss().cuda(output_device)  # contrastive loss

    # train and validate model
    print("---------- Start training ----------")

    # Running averages over the last 20 epochs; currently only consumed by
    # the commented-out report near the end of this function.
    avg_cls_acc = 0
    avg_edit_score = 0
    avg_segment_f1s = [0, 0, 0, 0, 0]
    avg_bound_acc = 0
    avg_precision = 0
    avg_recall = 0
    avg_bound_f1s = 0

    for epoch in range(begin_epoch, config.max_epoch):
        # training
        start = time.time()
        train_loss = train(
            train_loader,
            model,
            model_text,
            class_text_list,
            joint_text_list,
            criterion_cls,
            criterion_bound,
            criterion_contrast,
            config.lambda_b,
            optimizer,
            dataset_name,
            device, output_device
        )
        train_time = (time.time() - start) / 60

        # if you do validation to determine hyperparams
        if config.param_search:
            start = time.time()
            (
                val_loss,
                cls_acc,
                edit_score,
                segment_f1s,
                maps,
                bound_acc,
                precision,
                recall,
                bound_f1s,
            ) = validate(
                val_loader,
                model,
                model_text,
                joint_text_list,
                criterion_cls,
                criterion_bound,
                config.lambda_b,
                device, output_device,
                config.dataset,
                config.dataset_dir,
                config.iou_thresholds,
                config.boundary_th,
                config.tolerance,
                config.refinement_method,
            )

            if epoch >= config.max_epoch - 20:
                avg_cls_acc += cls_acc / 20
                avg_edit_score += edit_score / 20
                avg_segment_f1s = [a + b / 20 for a, b in zip(avg_segment_f1s, segment_f1s)]
                avg_bound_acc += bound_acc / 20
                avg_precision += precision / 20
                avg_recall += recall / 20
                avg_bound_f1s += bound_f1s / 20

            # Epoch 0 never updates the best-score trackers.
            if epoch > 0:
                if best_loss > val_loss:
                    best_loss = val_loss

                # save a model whenever one of the tracked scores improves
                if cls_acc > best_test_acc['cls_acc']:
                    change_label_score(best_test_acc, train_loss, epoch, cls_acc, edit_score, segment_f1s)
                    torch.save(
                        model.state_dict(),
                        os.path.join(result_path, 'best_test_acc_model.prm')
                    )

                if segment_f1s[0] > best_test_F1_10['f1s@0.1']:
                    change_label_score(best_test_F1_10, train_loss, epoch, cls_acc, edit_score, segment_f1s)
                    torch.save(
                        model.state_dict(),
                        os.path.join(result_path, 'best_test_F1_0.1_model.prm')
                    )

                if segment_f1s[2] > best_test_F1_50['f1s@0.5']:
                    change_label_score(best_test_F1_50, train_loss, epoch, cls_acc, edit_score, segment_f1s)
                    torch.save(
                        model.state_dict(),
                        os.path.join(result_path, 'best_test_F1_0.5_model.prm')
                    )

                if np.mean(maps) > best_test_MAP['mAP']:
                    change_label_score(best_test_MAP, train_loss, epoch, cls_acc, edit_score, segment_f1s, maps)
                    torch.save(
                        model.state_dict(),
                        os.path.join(result_path, 'best_test_map_model.prm')
                    )

        # save checkpoint every epoch
        save_checkpoint(result_path, epoch, model, optimizer, best_loss)

        # write logs to dataframe and csv file
        tmp = [epoch, optimizer.param_groups[0]["lr"], train_loss]

        # if you do validation to determine hyperparams
        if config.param_search:
            tmp += [
                val_loss,
                cls_acc,
                edit_score,
            ]
            tmp += segment_f1s
            tmp += [
                bound_acc,
                precision,
                recall,
                bound_f1s,
            ]

        tmp_df = pd.DataFrame(tmp, index=log.columns).T
        log = pd.concat([log, tmp_df], ignore_index=True)
        # index=False keeps the file's columns identical to `columns`, so a
        # later --resume can append rows of the same width.
        log.to_csv(log_path, index=False)

        # Without param_search `start` is still the training start time.
        val_time = (time.time() - start) / 60
        eta_time = (config.max_epoch - epoch) * (train_time + val_time)

        if config.param_search:
            # `maps` is aggregated with np.mean everywhere else; reduce it to
            # one float so the '{:.2f}' field also accepts a list/array.
            mean_map = float(np.mean(maps))
            epoch_summary = (
                'epoch: {}, lr: {:.4f}, train_time: {:.2f}min, train loss: {:.4f}, val_time: {:.2f}min, eta_time: {:.2f}min, \nval_loss: {:.4f}, acc: {:.2f}, edit: {:.2f}, F1@0.1: {:.2f}, F1@0.25: {:.2f}, F1@0.5: {:.2f}, bound_acc: {:.2f}, bound_f1: {:.2f},map: {:.2f}'
                .format(epoch, optimizer.param_groups[0]['lr'], train_time, train_loss, val_time, eta_time, val_loss, cls_acc,
                        edit_score, segment_f1s[0], segment_f1s[1], segment_f1s[2], bound_acc, bound_f1s, mean_map)
            )
            print(epoch_summary)
            with open(scores_path, "a+") as file:
                file.write(epoch_summary + '\n')
        else:
            print(
                "epoch: {}\tlr: {:.4f}\ttrain loss: {:.4f}".format(
                    epoch, optimizer.param_groups[0]["lr"], train_loss
                )
            )

    # delete checkpoint (absent when the loop above never ran)
    if os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)

    banner = '\n---------------------------{}---------------------------\n'
    best_reports = [
        ('best_test_acc', best_test_acc),
        ('best_test_F1_10', best_test_F1_10),
        ('best_test_F1_50', best_test_F1_50),
        ('best_test_MAP', best_test_MAP),
    ]
    all_train_time = 'all_train_time: {:.2f}min'.format((time.time() - start_start) / 60)

    for title, record in best_reports:
        print(banner.format(title))
        print('{}'.format(record))
    print(banner.format('all_train_time'))
    print(all_train_time)

    # Mirror the console summary (including the mAP record) into scores.txt.
    with open(scores_path, "a+") as file:
        for title, record in best_reports:
            file.write(banner.format(title))
            file.write('{}'.format(record))
        file.write(banner.format('all_train_time'))
        file.write(all_train_time)

    # print('avg_acc: {:.2f}, avg_edit: {:.2f}, avg_f1@10: {:.2f}, avg_f1@25: {:.2f}, avg_f1@50: {:.2f}, avg_bound_acc: {:.2f}, avg_precision: {:.2f}, avg_recall: {:.2f}, avg_bound_f1s: {:.2f}'
    #       .format(avg_cls_acc, avg_edit_score, avg_segment_f1s[0],avg_segment_f1s[1],avg_segment_f1s[2], avg_bound_acc, avg_precision, avg_recall, avg_bound_f1s)
    #       )
    #
    # with open(f'{result_path}/scores.txt', "a+") as file:
    #     file.write(
    #         'avg_acc: {:.2f}, avg_edit: {:.2f}, avg_f1@10: {:.2f}, avg_f1@25: {:.2f}, avg_f1@50: {:.2f}, avg_bound_acc: {:.2f}, avg_precision: {:.2f}, avg_recall: {:.2f}, avg_bound_f1s: {:.2f}\n'
    #         .format(avg_cls_acc, avg_edit_score, avg_segment_f1s[0],avg_segment_f1s[1],avg_segment_f1s[2], avg_bound_acc, avg_precision, avg_recall, avg_bound_f1s)
    #     )

    # Append the best-score records as extra rows of the final log.csv.
    best_test_acc = pd.DataFrame.from_dict(best_test_acc, orient='index').T
    best_test_F1_10 = pd.DataFrame.from_dict(best_test_F1_10, orient='index').T
    best_test_F1_50 = pd.DataFrame.from_dict(best_test_F1_50, orient='index').T
    log = pd.concat([log, best_test_acc], ignore_index=True)
    log = pd.concat([log, best_test_F1_10], ignore_index=True)
    log = pd.concat([log, best_test_F1_50], ignore_index=True)
    log.to_csv(os.path.join(result_path, 'log.csv'), index=False)

    print("Done!")
# Run training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()