###
# demo.py
# Define model classes for inference.
###

import json
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
from sklearn.metrics import confusion_matrix
from einops import rearrange
from transformers import BertTokenizer

from svitt.model import SViTT
from svitt.datasets import VideoClassyDataset
from svitt.video_transforms import Permute
from svitt.config import load_cfg, setup_config
from svitt.evaluation_charades import charades_map
from svitt.evaluation import get_mean_accuracy


class VideoModel(nn.Module):
    """ Base model for video understanding based on SViTT architecture. """

    def __init__(self, config):
        """ Initializes the model.

        Parameters:
            config: config file, parsed with ``load_cfg``.
        """
        super(VideoModel, self).__init__()
        self.cfg = load_cfg(config)
        self.model = self.build_model()
        self.templates = ['{}']
        self.dataset = self.cfg['data']['dataset']
        self.eval()

    def build_model(self):
        """Build the SViTT model and tokenizer and load pretrained weights.

        Reads ``cfg['model']['pretrain']`` (checkpoint path) and
        ``cfg['model']['config']`` (model config path). As side effects it sets
        ``self.model_cfg`` and ``self.tokenizer``.

        Returns:
            The SViTT model, moved to CUDA when a GPU is available.

        Raises:
            Exception: if the checkpoint path or the model config path is
                missing from the config.
        """
        cfg = self.cfg
        if cfg['model'].get('pretrain', False):
            ckpt_path = cfg['model']['pretrain']
        else:
            raise Exception('no checkpoint found')
        if cfg['model'].get('config', False):
            config_path = cfg['model']['config']
        else:
            raise Exception('no model config found')

        self.model_cfg = setup_config(config_path)
        self.tokenizer = BertTokenizer.from_pretrained(self.model_cfg.text_encoder)
        model = SViTT(config=self.model_cfg, tokenizer=self.tokenizer)

        print(f"Loading checkpoint from {ckpt_path}")
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = checkpoint["model"]

        # fix for zero-shot evaluation: every weight whose key contains "bert"
        # is also registered under the same key with "bert." removed, so both
        # naming schemes are available to load_state_dict below.
        for key in list(state_dict.keys()):
            if "bert" in key:
                encoder_key = key.replace("bert.", "")
                state_dict[encoder_key] = state_dict[key]

        if torch.cuda.is_available():
            model.cuda()
        # strict=False: keys that do not match the model are silently ignored.
        model.load_state_dict(state_dict, strict=False)
        return model

    def eval(self):
        """Freeze all SViTT parameters and switch the model to eval mode.

        Also enables ``cudnn.benchmark`` globally.

        Returns:
            self, following the ``nn.Module.eval`` convention.
        """
        cudnn.benchmark = True
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()
        return self


class VideoCLSModel(VideoModel):
    """ Video model for video classification tasks (Charades-Ego, EGTEA). """

    def __init__(self, config):
        super(VideoCLSModel, self).__init__(config)
        self.labels, self.mapping_vn2act = self.gen_label_map()
        self.text_features = self.get_text_features()

    def gen_label_map(self):
        """Load the label map from disk, or generate and cache it.

        The path comes from ``cfg['label_map']`` and defaults to
        ``meta/<dataset>/label_map.json``. If the file does not exist, it is
        generated with ``svitt.preprocess.generate_label_map`` and written to
        that same path, so the next run finds it.

        Returns:
            (labels, mapping_vn2act) as stored in the JSON file.
        """
        # Default path is dataset-specific so that e.g. EGTEA never picks up
        # the Charades-Ego label map, and matches where the cache is written.
        labelmap = self.cfg.get('label_map', f'meta/{self.dataset}/label_map.json')
        if os.path.isfile(labelmap):
            print(f"=> Loading label maps from {labelmap}")
            with open(labelmap, 'r') as f:
                meta = json.load(f)
            labels, mapping_vn2act = meta['labels'], meta['mapping_vn2act']
        else:
            from svitt.preprocess import generate_label_map
            labels, mapping_vn2act = generate_label_map(self.dataset)
            meta = {'labels': labels, 'mapping_vn2act': mapping_vn2act}
            meta_dir = os.path.dirname(labelmap)
            if meta_dir:
                os.makedirs(meta_dir, exist_ok=True)
            # `with` guarantees the file is flushed and closed before it is re-read.
            with open(labelmap, 'w') as f:
                json.dump(meta, f)
            print(f"=> Label map is generated and saved to {labelmap}")
        return labels, mapping_vn2act

    def load_data(self, idx=None, batch_size=8, num_workers=4):
        """Create the validation DataLoader.

        Parameters:
            idx: if given, substituted into ``cfg['data']['metadata_val']``
                via ``str.format``.
            batch_size: DataLoader batch size.
            num_workers: DataLoader worker processes.

        Raises:
            NotImplementedError: for datasets other than charades_ego / egtea.
        """
        print(f"=> Creating dataset")
        dataset = self.dataset
        data_cfg = self.cfg['data']

        crop_size = 224
        val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),  # T H W C -> C T H W
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms_video.NormalizeVideo(
                mean=[108.3272985, 116.7460125, 104.09373615000001],
                std=[68.5005327, 66.6321579, 70.32316305],
            ),
        ])

        if idx is None:
            metadata_val = data_cfg['metadata_val']
        else:
            metadata_val = data_cfg['metadata_val'].format(idx)

        if dataset not in ['charades_ego', 'egtea']:
            raise NotImplementedError
        val_dataset = VideoClassyDataset(
            dataset, data_cfg['root'], metadata_val,
            transform=val_transform, is_training=False,
            label_mapping=self.mapping_vn2act,
            is_trimmed=False, num_clips=1,
            clip_length=data_cfg['clip_length'],
            clip_stride=data_cfg['clip_stride'],
            sparse_sample=data_cfg['sparse_sample'],
        )

        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False,
            num_workers=num_workers, pin_memory=True, sampler=None, drop_last=False,
        )
        return val_loader

    @torch.no_grad()
    def get_text_features(self):
        """Encode every class label with the SViTT text encoder.

        Returns:
            The second output of ``self.model.encode_text`` for all labels.
        """
        print('=> Extracting text features')
        # Tokenized inputs must live on the same device as the model (CUDA when
        # available); `.to` is a no-op when the model is on CPU.
        device = next(self.model.parameters()).device
        embeddings = self.tokenizer(
            self.labels,
            padding="max_length",
            truncation=True,
            max_length=self.model_cfg.max_txt_l.video,
            return_tensors="pt",
        ).to(device)
        _, class_embeddings = self.model.encode_text(embeddings)
        return class_embeddings

    @torch.no_grad()
    def forward(self, idx=None):
        """Score every validation video against all class-label features.

        Returns:
            (all_outputs, all_targets): CPU tensors concatenated over the loader.
            There is one similarity row per video, aligned with its target.
        """
        print('=> Start forwarding')
        val_loader = self.load_data(idx)
        all_outputs = []
        all_targets = []
        for values in val_loader:
            images = values[0]
            target = values[1]
            if torch.cuda.is_available():
                images = images.cuda(non_blocking=True)

            # encode images
            images = rearrange(images, 'b c k h w -> b k c h w')
            dims = images.shape
            batch_size = dims[0]
            # Split each clip into consecutive chunks of 4 frames; presumably the
            # encoder expects 4-frame inputs and clip_length is a multiple of 4 —
            # TODO confirm.
            images = images.reshape(-1, 4, dims[-3], dims[-2], dims[-1])
            image_features, _ = self.model.encode_image(images)
            # Regroup chunk features per video. Using the real batch size (not a
            # hard-coded 1) keeps one output row per video, aligned with targets.
            if image_features.ndim == 3:
                image_features = rearrange(image_features, '(b k) n d -> b (k n) d', b=batch_size)
            else:
                image_features = rearrange(image_features, '(b k) d -> b k d', b=batch_size)

            # cosine similarity as logits
            similarity = self.model.get_sim(image_features, self.text_features)[0]
            all_outputs.append(similarity.cpu())
            all_targets.append(target.cpu())

        all_outputs = torch.cat(all_outputs)
        all_targets = torch.cat(all_targets)
        return all_outputs, all_targets

    @torch.no_grad()
    def predict(self, idx=0, top_k=5):
        """Return the top-scoring and ground-truth labels for the first video.

        Parameters:
            idx: metadata index forwarded to ``load_data``.
            top_k: number of highest-scoring labels to return.

        Returns:
            (pred_action, gt_action): sorted label lists. Ground truth is taken
            from the non-zero entries of the first target row (assumes a
            multi-hot target, as in Charades-Ego — verify for other datasets).
        """
        all_outputs, all_targets = self.forward(idx)
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        df = pd.DataFrame(self.labels)
        pred_action = df.iloc[preds[0].argsort()[-top_k:]].values.tolist()
        gt_action = df.iloc[np.where(targets[0])[0]].values.tolist()
        pred_action = sorted([x[0] for x in pred_action])
        gt_action = sorted([x[0] for x in gt_action])
        return pred_action, gt_action

    @torch.no_grad()
    def evaluate(self):
        """Evaluate on the full validation set, then print and return metrics.

        Returns:
            mAP for charades_ego; (mean class accuracy, top-1 accuracy) for egtea.

        Raises:
            NotImplementedError: for any other dataset.
        """
        all_outputs, all_targets = self.forward()
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        if self.dataset == 'charades_ego':
            m_ap, _, m_aps = charades_map(preds, targets)
            print('mAP = {:.3f}'.format(m_ap))
            return m_ap
        elif self.dataset == 'egtea':
            cm = confusion_matrix(targets, preds.argmax(axis=1))
            mean_class_acc, acc = get_mean_accuracy(cm)
            print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_class_acc, acc))
            return mean_class_acc, acc
        else:
            raise NotImplementedError