import json

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from einops import rearrange
from transformers import BertTokenizer
from torchvision import transforms
from torchvision.transforms._transforms_video import (
    NormalizeVideo,
)

from svitt.model import SViTT
from svitt.config import load_cfg, setup_config
from svitt.base_dataset import read_frames_cv2_egoclip


class VideoModel(nn.Module):
    """Base model for video understanding built on the SViTT architecture."""

    def __init__(self, config):
        """Initializes the model.

        Parameters:
            config: configuration file (path), loaded via load_cfg.
        """
        super().__init__()
        self.cfg = load_cfg(config)
        self.model = self.build_model()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.templates = ['{}']
        self.dataset = self.cfg['data']['dataset']
        self.eval()

    def build_model(self):
        cfg = self.cfg
        if cfg['model'].get('pretrain', False):
            ckpt_path = cfg['model']['pretrain']
        else:
            raise Exception('no checkpoint found')

        if cfg['model'].get('config', False):
            config_path = cfg['model']['config']
        else:
            raise Exception('no model config found')

        self.model_cfg = setup_config(config_path)
        self.tokenizer = BertTokenizer.from_pretrained(self.model_cfg.text_encoder)
        model = SViTT(config=self.model_cfg, tokenizer=self.tokenizer)

        print(f"Loading checkpoint from {ckpt_path}")
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = checkpoint["model"]

        # Duplicate text-encoder weights under keys without the "bert." prefix so
        # they match the model's parameter names; leftover original keys are
        # ignored because the state dict is loaded with strict=False.
        for key in list(state_dict.keys()):
            if "bert" in key:
                encoder_key = key.replace("bert.", "")
                state_dict[encoder_key] = state_dict[key]

        if torch.cuda.is_available():
            model.cuda()

        model.load_state_dict(state_dict, strict=False)

        return model

    def eval(self):
        """Disables gradients and puts the wrapped model in inference mode."""
        cudnn.benchmark = True
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()


class VideoCLSModel(VideoModel):
    """Video model for video classification tasks (Charades-Ego, EGTEA)."""

    def __init__(self, config, sample_videos):
        super().__init__(config)
        self.sample_videos = sample_videos
        self.video_transform = self.init_video_transform()

    def init_video_transform(self,
                             input_res=224,
                             center_crop=256,
                             norm_mean=(0.485, 0.456, 0.406),
                             norm_std=(0.229, 0.224, 0.225),
                             ):
        """Builds the preprocessing pipeline applied to (C, T, H, W) video tensors."""
        print('Video Transform is used!')
        normalize = NormalizeVideo(mean=norm_mean, std=norm_std)
        return transforms.Compose(
            [
                transforms.Resize(center_crop),
                transforms.CenterCrop(center_crop),
                transforms.Resize(input_res),
                normalize,
            ]
        )

    def load_data(self, idx):
        """Loads and preprocesses the candidate clips for sample `idx`."""
        num_frames = self.model_cfg.video_input.num_frames
        video_paths = self.sample_videos[idx]
        clips = [None] * len(video_paths)
        for i, path in enumerate(video_paths):
            imgs = read_frames_cv2_egoclip(path, num_frames, 'uniform')
            # Swap the temporal and channel dims so NormalizeVideo sees
            # (C, T, H, W), then swap back after the transform.
            imgs = imgs.transpose(0, 1)
            imgs = self.video_transform(imgs)
            imgs = imgs.transpose(0, 1)
            clips[i] = imgs
        return torch.stack(clips)

    def load_meta(self, idx=None):
        """Reads the metadata JSON for sample `idx`."""
        filename = f"{self.cfg['data']['root']}/{idx}/meta.json"
        with open(filename, "r") as f:
            meta = json.load(f)
        return meta

    @torch.no_grad()
    def get_text_features(self, text):
        """Tokenizes `text` and returns pooled text embeddings."""
        print('=> Extracting text features')
        embeddings = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.model_cfg.max_txt_l.video,
            return_tensors="pt",
        ).to(self.device)
        _, class_embeddings = self.model.encode_text(embeddings)
        return class_embeddings

    @torch.no_grad()
    def forward(self, idx, text=None):
        """Computes video-text similarity for sample `idx` and returns (predicted index, target)."""
        print('=> Start forwarding')
        meta = self.load_meta(idx)
        clips = self.load_data(idx)
        if text is None:
            text = meta["text"][4:]
        text_features = self.get_text_features(text)
        target = meta["correct"]

        # Encode each clip separately and pool its visual features.
        pooled_image_feat_all = []
        for i in range(clips.shape[0]):
            images = clips[i, :].unsqueeze(0).to(self.device)
            bsz = images.shape[0]

            _, pooled_image_feat, *outputs = self.model.encode_image(images)
            if pooled_image_feat.ndim == 3:
                pooled_image_feat = rearrange(pooled_image_feat, '(b k) n d -> b (k n) d', b=bsz)
            else:
                pooled_image_feat = rearrange(pooled_image_feat, '(b k) d -> b k d', b=bsz)

            pooled_image_feat_all.append(pooled_image_feat)

        # Similarity between all pooled clip features and the text features.
        pooled_image_feat_all = torch.cat(pooled_image_feat_all, dim=0)
        similarity = self.model.get_sim(pooled_image_feat_all, text_features)[0]
        return similarity.argmax(), target

    @torch.no_grad()
    def predict(self, idx, text=None):
        output, target = self.forward(idx, text)
        return output.cpu().numpy(), target
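

# Minimal usage sketch (not part of the original module). The config path and
# clip paths below are hypothetical placeholders; `sample_videos` is assumed to
# map each sample index to its list of candidate clip paths, and the configured
# data root is assumed to contain `<idx>/meta.json` for each sample.
if __name__ == "__main__":
    sample_videos = [
        ["videos/sample0/clip_a.mp4", "videos/sample0/clip_b.mp4"],  # candidates for sample 0
    ]
    model = VideoCLSModel("configs/eval.yml", sample_videos)
    prediction, target = model.predict(0)
    print(f"predicted index: {prediction}, ground truth: {target}")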