# Multimodal_Demo / models/clip_model.py
# Duplicated from TencentARC/VLog by Sirus1 (commit 6f6830f).
import os
import cv2
import pdb
import torch
import numpy as np
from PIL import Image
from transformers import CLIPProcessor, CLIPVisionModelWithProjection
from transformers import logging
# Limit transformers' logging to error-level messages (hides its info and
# warning output).
logging.set_verbosity_error()
class FeatureExtractor:
    """Compute CLIP image embeddings for frames sampled from a video.

    Features are cached per video in ``<tmp_dir>/<video_id>.npz`` under the
    key ``'features'``. Later calls with the same ``video_id`` load them from
    there instead of re-running the model.
    """

    def __init__(self, args):
        """Load the CLIP processor and vision model.

        Args:
            args: Config object. Attributes read:
                ``feature_extractor``: model name/path passed to
                ``from_pretrained`` for both the processor and the model.
                ``feature_extractor_device``: device the model is moved to
                and inputs are sent to.
                ``beta``: frame-sampling multiplier (see ``__call__``).
                ``data_dir``: stored only; not read by this class.
                ``tmp_dir``: directory holding the ``.npz`` feature cache.
        """
        self.device = args.feature_extractor_device
        self.beta = args.beta
        self.processor = CLIPProcessor.from_pretrained(args.feature_extractor)
        self.model = CLIPVisionModelWithProjection.from_pretrained(
            args.feature_extractor
        ).to(self.device)
        self.data_dir = args.data_dir
        self.tmp_dir = args.tmp_dir

    def __call__(self, video_path, video_id):
        """Return CLIP embeddings of sampled frames and the video duration.

        One frame is kept every ``int(fps) * beta`` frames, which is roughly
        one frame per ``beta`` seconds. If ``<tmp_dir>/<video_id>.npz``
        already exists, its ``'features'`` array is returned instead of
        re-running the model. Otherwise the freshly computed features are
        written there.

        Args:
            video_path: Path to a video readable by ``cv2.VideoCapture``.
            video_id: Identifier used to name the cache file.

        Returns:
            ``(clip_features, video_length)``. ``clip_features`` is a NumPy
            array of the model's ``image_embeds`` outputs concatenated along
            axis 0, one entry per sampled frame. ``video_length`` is
            ``frame_count / fps`` (seconds).

        Raises:
            ValueError: the video cannot be opened or reports a non-positive
                FPS; the sampling interval is not positive; or no frame was
                sampled.
        """
        cap = cv2.VideoCapture(video_path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            # cap.get() returns 0 for properties the backend cannot provide
            # (e.g. when the file failed to open). The division below would
            # then raise an unhelpful ZeroDivisionError.
            if not cap.isOpened() or fps <= 0:
                raise ValueError(
                    f"Cannot read video {video_path!r} "
                    f"(opened={cap.isOpened()}, fps={fps})"
                )
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            video_length = frame_count / fps

            save_path = os.path.join(self.tmp_dir, video_id + '.npz')
            if os.path.exists(save_path):
                # ``with`` closes the NpzFile's file handle. Indexing it has
                # already read the array into memory.
                with np.load(save_path) as data:
                    clip_features = data['features']
                return clip_features, video_length

            # int(fps) is 0 when fps < 1. A non-positive interval would make
            # the modulo test below divide by zero or never match.
            sample_rate = int(fps) * self.beta
            if sample_rate <= 0:
                raise ValueError(
                    f"Invalid frame sampling interval {sample_rate} "
                    f"(fps={fps}, beta={self.beta}) for {video_path!r}"
                )

            print("Extract the clip feature.")
            features = self._embed_sampled_frames(cap, sample_rate)
            print("Finished.")
        finally:
            # Release the capture on every path: cache hit, success, or error.
            cap.release()

        if not features:
            raise ValueError(
                f"No frames were sampled from {video_path!r}: it has fewer "
                f"than {sample_rate} readable frames."
            )
        clip_features = np.concatenate(features, axis=0)

        # Create the cache directory so saving does not fail after extraction.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        np.savez_compressed(save_path, features=clip_features)
        return clip_features, video_length

    def _embed_sampled_frames(self, cap, sample_rate):
        """Run CLIP on every ``sample_rate``-th frame read from ``cap``.

        Returns a list of NumPy arrays, one per sampled frame, in read order.
        """
        features = []
        frames_read = 0
        # A single no_grad context for the whole loop, not one per frame.
        with torch.no_grad():
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frames_read += 1
                # frames_read equals what CAP_PROP_POS_FRAMES reports after
                # the read (index of the next frame to decode). Counting
                # locally avoids querying the backend on every frame.
                if frames_read % sample_rate != 0:
                    continue
                # Convert OpenCV's BGR frame to RGB before preprocessing.
                image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pixel_values = self.processor(
                    images=image, return_tensors="pt"
                ).pixel_values
                embeds = self.model(pixel_values.to(self.device))['image_embeds']
                features.append(embeds.cpu().numpy())
        return features