# ChatVID/model/utils/extract_clip_feature.py
import clip
import numpy as np
import torch
from mmaction.datasets.transforms import (CenterCrop, DecordDecode, DecordInit,
FormatShape, Resize)
from torchvision import transforms
def extract_clip_feature_single_video_fps(
        video_path: str,
        clip_ckpt_path: str = 'ViT-L-14.pt',
        device: str = 'cuda',
        fps: int = 1):
    """Decode a video, sample frames at ``fps`` frames per second, and
    encode the sampled frames with a CLIP image encoder.

    Args:
        video_path: Path to a video file readable by decord.
        clip_ckpt_path: CLIP checkpoint name/path passed to ``clip.load``.
        device: Torch device string the model and frame tensor are moved to.
        fps: Target sampling rate in frames per second. Defaults to 1,
            which preserves the original single-rate behaviour.

    Returns:
        tuple: ``(video_feat, video_info)`` where ``video_feat`` holds the
        per-frame CLIP image features and ``video_info`` is the mmaction
        pipeline dict (frame indices, shapes, ...).
    """

    class SampleFramesFPS(object):
        '''Sample frames at ``fps`` frames per second.

        Generalizes the former separate 1-fps / 5-fps samplers (the 5-fps
        variant was dead code) into one parameterized transform.

        Required Keys:
        - total_frames
        - start_index
        - avg_fps
        Added Keys:
        - frame_interval
        - frame_inds
        - num_clips
        '''

        def __init__(self, fps: int = 1):
            self.fps = fps

        def transform(self, video_info: dict) -> dict:
            # Take one frame every (avg_fps / fps) decoded frames;
            # dtype=int truncates the float positions to frame indices.
            # With fps=1 the step equals avg_fps, matching the original.
            video_info['frame_inds'] = np.arange(
                video_info['start_index'],
                video_info['total_frames'],
                video_info['avg_fps'] / self.fps,
                dtype=int)  # np.arange(start, stop, step, dtype)
            video_info['frame_interval'] = 1
            video_info['num_clips'] = len(video_info['frame_inds'])
            return video_info

    video_info = {'filename': video_path, 'start_index': 0}
    video_processors = [
        DecordInit(),
        SampleFramesFPS(fps=fps),
        DecordDecode(),
        Resize(scale=(-1, 224)),
        CenterCrop(crop_size=224),
        FormatShape(input_format='NCHW'),
    ]
    # Decode the video into a batched uint8 frame array (NCHW).
    for processor in video_processors:
        video_info = processor.transform(video_info)
    imgs = torch.from_numpy(video_info['imgs'])  # uint8 img tensor

    # CLIP's published preprocessing statistics: uint8 -> float32 in
    # [0, 1], then per-channel normalization.
    imgs_transforms = transforms.Compose([
        transforms.ConvertImageDtype(dtype=torch.float32),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
            inplace=False)
    ])
    imgs = imgs_transforms(imgs).to(device)

    # NOTE(review): the CLIP model is re-loaded on every call; hoist the
    # load out of this function when processing many videos in a row.
    clip_model, _ = clip.load(clip_ckpt_path, device)

    # Encode all sampled frames in a single forward pass (no gradients).
    with torch.no_grad():
        video_feat = clip_model.encode_image(imgs)
    return video_feat, video_info
if __name__ == '__main__':
    # Extract CLIP features for a fixed list of demo videos and save each
    # feature array as an .npy file next to a 'clip_features/20/' prefix.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    video_names = [
        'cook.mp4', 'latex.mp4', 'nba.mp4', 'temple_of_heaven.mp4',
        'south_pole.mp4', 'tv_series.mp4', 'formula_one.mp4', 'make-up.mp4',
        'police.mp4'
    ]
    video_dir = '/mnt/petrelfs/wangyiqin/vid_cap/examples/videos/'
    for video_name in video_names:
        # BUG FIX: the extractor returns (features, video_info); the
        # original bound the whole tuple to ``video_feat`` and then
        # crashed on ``.cpu()``. Unpack the tuple instead.
        video_feat, _ = extract_clip_feature_single_video_fps(
            video_path=video_dir + video_name,
            clip_ckpt_path='ViT-L-14.pt',
            device=device)
        # Move to CPU and convert to a NumPy array before saving.
        video_feat = video_feat.cpu().numpy()
        np.save('clip_features/20/' + video_name[:-4] + '.npy', video_feat)
        print(video_feat.shape)
        print(video_name + ' DONE')