| | import os |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from dataset.ucf_jhmdb import UCF_JHMDB_Dataset |
| | from dataset.ava import AVA_Dataset |
| | from dataset.transforms import Augmentation, BaseTransform |
| |
|
| | from evaluator.ucf_jhmdb_evaluator import UCF_JHMDB_Evaluator |
| | from evaluator.ava_evaluator import AVA_Evaluator |
| |
|
| |
|
def build_dataset(d_cfg, args, is_train=False):
    """
    Build the dataset and (optionally) its evaluator.

    Args:
        d_cfg (dict): dataset config; must provide 'train_size', 'test_size',
            'jitter', 'hue', 'saturation', 'exposure', 'sampling_rate' and,
            for ucf24/jhmdb21, 'gt_folder'.
        args: parsed command-line args; reads .dataset, .root, .len_clip,
            .version, .test_batch_size, .eval.
        is_train (bool): build the training split when True.

    Returns:
        (dataset, evaluator, num_classes) — evaluator is None when
        args.eval is False.

    Raises:
        ValueError: if args.dataset is not one of the supported names.
    """
    # Training-time augmentation (color jitter + spatial jitter).
    augmentation = Augmentation(
        img_size=d_cfg['train_size'],
        jitter=d_cfg['jitter'],
        hue=d_cfg['hue'],
        saturation=d_cfg['saturation'],
        exposure=d_cfg['exposure']
    )
    # Deterministic resize-only transform used for evaluation.
    basetransform = BaseTransform(
        img_size=d_cfg['test_size'],
    )

    if args.dataset in ['ucf24', 'jhmdb21']:
        # Both UCF24 and JHMDB21 live under the same 'ucf24' directory layout.
        data_dir = os.path.join(args.root, 'ucf24')

        dataset = UCF_JHMDB_Dataset(
            data_root=data_dir,
            dataset=args.dataset,
            img_size=d_cfg['train_size'],
            transform=augmentation,
            is_train=is_train,
            len_clip=args.len_clip,
            sampling_rate=d_cfg['sampling_rate']
        )
        num_classes = dataset.num_classes

        evaluator = UCF_JHMDB_Evaluator(
            data_root=data_dir,
            dataset=args.dataset,
            model_name=args.version,
            metric='fmap',
            img_size=d_cfg['test_size'],
            len_clip=args.len_clip,
            batch_size=args.test_batch_size,
            conf_thresh=0.01,
            iou_thresh=0.5,
            gt_folder=d_cfg['gt_folder'],
            save_path='./evaluator/eval_results/',
            transform=basetransform,
            collate_fn=CollateFunc()
        )

    elif args.dataset == 'ava_v2.2':
        data_dir = args.root

        dataset = AVA_Dataset(
            cfg=d_cfg,
            data_root=data_dir,
            # Bug fix: was hard-coded to True, which ignored the caller's
            # is_train flag and always built the training split.
            is_train=is_train,
            img_size=d_cfg['train_size'],
            transform=augmentation,
            len_clip=args.len_clip,
            sampling_rate=d_cfg['sampling_rate']
        )
        # NOTE(review): 3 looks unusual for AVA v2.2 (commonly 80 action
        # classes) — confirm against the AVA_Dataset label map.
        num_classes = 3

        evaluator = AVA_Evaluator(
            d_cfg=d_cfg,
            data_root=data_dir,
            img_size=d_cfg['test_size'],
            len_clip=args.len_clip,
            sampling_rate=d_cfg['sampling_rate'],
            batch_size=args.test_batch_size,
            transform=basetransform,
            collate_fn=CollateFunc(),
            full_test_on_val=False,
            version='v2.2'
        )

    else:
        # Bug fix: previously print + exit(0), which terminated with a
        # *success* status on a fatal configuration error.
        raise ValueError(
            'unknown dataset: {} !! Only support ucf24 & jhmdb21 & ava_v2.2 !!'
            .format(args.dataset))

    print('==============================')
    print('Training model on:', args.dataset)
    print('The dataset size:', len(dataset))

    # The evaluator is only kept when evaluation was requested.
    if not args.eval:
        evaluator = None

    return dataset, evaluator, num_classes
| |
|
| |
|
def build_dataloader(args, dataset, batch_size, collate_fn=None, is_train=False):
    """
    Wrap *dataset* in a torch DataLoader.

    Args:
        args: reads .distributed (bool) and .num_workers (int).
        dataset: any map-style dataset (implements __len__/__getitem__).
        batch_size (int): samples per batch for both train and eval loaders.
        collate_fn: optional custom batch collator.
        is_train (bool): training loader (shuffled, drop_last) when True,
            sequential evaluation loader otherwise.

    Returns:
        torch.utils.data.DataLoader
    """
    if is_train:
        # Distributed training shards the data across processes; otherwise
        # plain random shuffling.
        if args.distributed:
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            sampler = torch.utils.data.RandomSampler(dataset)

        # drop_last=True keeps every training batch at a uniform size.
        batch_sampler_train = torch.utils.data.BatchSampler(sampler,
                                                            batch_size,
                                                            drop_last=True)
        dataloader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_sampler=batch_sampler_train,
            collate_fn=collate_fn,
            num_workers=args.num_workers,
            pin_memory=True
        )
    else:
        dataloader = torch.utils.data.DataLoader(
            dataset=dataset,
            # Bug fix: batch_size was omitted here, so the eval loader
            # silently fell back to DataLoader's default batch_size=1 and
            # ignored the caller's value.
            batch_size=batch_size,
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=args.num_workers,
            drop_last=False,
            pin_memory=True
        )

    return dataloader
| | |
| |
|
def load_weight(model, path_to_ckpt=None):
    """
    Load checkpoint weights into model[0].

    Args:
        model: an indexable container whose element 0 is the nn.Module to
            load into (this indexing convention comes from the callers —
            the interface is kept unchanged).
        path_to_ckpt (str | None): checkpoint path; the file must hold a
            dict with a "model" state-dict entry. When None, nothing is
            loaded.

    Returns:
        The nn.Module model[0], with checkpoint weights applied when a
        path was given.
    """
    if path_to_ckpt is None:
        print('No trained weight ..')
        # Bug fix: previously returned the container `model` here but
        # `model[0]` in the loading branch below — callers received a
        # different type depending on whether a checkpoint was given.
        return model[0]

    # map_location='cpu' so checkpoints saved on GPU load on any machine.
    checkpoint = torch.load(path_to_ckpt, map_location='cpu')
    checkpoint_state_dict = checkpoint.pop("model")

    # Drop checkpoint entries that the current model cannot accept:
    # unknown keys and shape-mismatched tensors (e.g. a different head).
    model_state_dict = model[0].state_dict()
    for k in list(checkpoint_state_dict.keys()):
        if k in model_state_dict:
            shape_model = tuple(model_state_dict[k].shape)
            shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
            if shape_model != shape_checkpoint:
                checkpoint_state_dict.pop(k)
                # Consistency fix: mismatches were dropped silently before,
                # while unknown keys were printed.
                print('Skipped (shape mismatch):', k)
        else:
            checkpoint_state_dict.pop(k)
            print('Skipped (not in model):', k)

    model[0].load_state_dict(checkpoint_state_dict)
    print('Finished loading model!')

    return model[0]
| |
|
| |
|
def is_parallel(model):
    """Return True if *model* is wrapped in DataParallel or DistributedDataParallel."""
    # Fix: the original used an exact `type(model) in (...)` comparison,
    # which misses subclasses of the parallel wrappers; isinstance covers
    # both the exact classes and any subclass.
    return isinstance(model, (nn.parallel.DataParallel,
                              nn.parallel.DistributedDataParallel))
| |
|
| |
|
class CollateFunc(object):
    """Batch collator for (frame_id, video_clip, key_target) samples."""

    def __call__(self, batch):
        """
        Collate a list of dataset samples into one batch.

        Args:
            batch: iterable of 3-tuples (frame_id, video_clip tensor,
                key_target).

        Returns:
            (frame_ids, video_clips, key_targets) where frame_ids and
            key_targets are plain lists and video_clips is the per-sample
            clips stacked into a single tensor along a new batch dim.
        """
        frame_ids = [sample[0] for sample in batch]
        key_targets = [sample[2] for sample in batch]
        # Stacking adds the leading batch dimension; every clip must
        # therefore share the same shape.
        video_clips = torch.stack([sample[1] for sample in batch])

        return frame_ids, video_clips, key_targets
| |
|