import math

import ffmpeg
import numpy as np
import torch

from run_on_video import clip
class ClipFeatureExtractor:
    """Extracts CLIP features: one pooled image embedding per sampled video frame,
    and per-token text embeddings for each query string."""

    def __init__(self, framerate=1/2, size=224, centercrop=True, model_name_or_path="ViT-B/32", device="cuda"):
        self.video_loader = VideoLoader(framerate=framerate, size=size, centercrop=centercrop)
        print("Loading CLIP models")
        self.clip_extractor, _ = clip.load(model_name_or_path, device=device, jit=False)
        self.tokenizer = clip.tokenize
        self.video_preprocessor = Preprocessing()
        self.device = device
    @torch.no_grad()  # feature extraction only; skip autograd bookkeeping
    def encode_video(self, video_path: str, bsz=60):
        video_frames = self.video_loader.read_video_from_file(video_path)  # (T, 3, H, W)
        video_frames = self.video_preprocessor(video_frames)
        n_frames = len(video_frames)
        n_batch = int(math.ceil(n_frames / bsz))
        video_features = []
        # Encode frames in mini-batches of size bsz to bound GPU memory use.
        for i in range(n_batch):
            st_idx = i * bsz
            ed_idx = (i + 1) * bsz
            _video_frames = video_frames[st_idx:ed_idx].to(self.device)
            _video_features = self.clip_extractor.encode_image(_video_frames)
            video_features.append(_video_features)
        video_features = torch.cat(video_features, dim=0)
        return video_features  # (T=#frames, d) torch tensor
    @torch.no_grad()
    def encode_text(self, text_list, bsz=60):
        n_text = len(text_list)
        n_batch = int(math.ceil(n_text / bsz))
        text_features = []
        for i in range(n_batch):
            st_idx = i * bsz
            ed_idx = (i + 1) * bsz
            encoded_texts = self.tokenizer(text_list[st_idx:ed_idx], context_length=77).to(self.device)
            output = self.clip_extractor.encode_text(encoded_texts)
            # Token id 0 is padding, so counting non-zero ids gives each query's true length.
            valid_lengths = (encoded_texts != 0).sum(1).tolist()
            batch_last_hidden_states = output["last_hidden_state"]
            for j, valid_len in enumerate(valid_lengths):
                text_features.append(batch_last_hidden_states[j, :valid_len])
        return text_features  # list of (L_j, d) torch tensors, one per input string
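# Call-pattern sketch (illustrative; "example.mp4" and the query are placeholders,
# and a CUDA device plus the run_on_video CLIP weights are assumed):
#   extractor = ClipFeatureExtractor(framerate=1/2, size=224, device="cuda")
#   video_feats = extractor.encode_video("example.mp4")          # (T, d) tensor
#   text_feats = extractor.encode_text(["a person is cooking"])  # list of (L_j, d)
# Note the asymmetry: encode_video returns one pooled embedding per sampled frame,
# while encode_text keeps per-token states trimmed to each query's valid length.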
def convert_to_float(frac_str):
    """Parse ffprobe rate strings such as "30", "30000/1001", or "2 1/2".

    Returns None when the string is neither a number nor a single fraction.
    """
    try:
        return float(frac_str)
    except ValueError:
        try:
            num, denom = frac_str.split('/')
        except ValueError:
            return None
        try:
            leading, num = num.split(' ')
        except ValueError:
            return float(num) / float(denom)
        if float(leading) < 0:
            sign_mult = -1
        else:
            sign_mult = 1
        return float(leading) + sign_mult * (float(num) / float(denom))
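# Expected behavior sketch (ffprobe reports avg_frame_rate as a fraction string;
# the values below follow from the parsing logic above):
#   convert_to_float("30")         -> 30.0
#   convert_to_float("30000/1001") -> 29.97...  (NTSC frame rate)
#   convert_to_float("2 1/2")      -> 2.5       (mixed number)
#   convert_to_float("abc")        -> None      (no slash to split on)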
class Normalize(object):
    def __init__(self, mean, std):
        # Reshape to (1, 3, 1, 1) so the stats broadcast over (T, 3, H, W) frames.
        self.mean = torch.FloatTensor(mean).view(1, 3, 1, 1)
        self.std = torch.FloatTensor(std).view(1, 3, 1, 1)

    def __call__(self, tensor):
        tensor = (tensor - self.mean) / (self.std + 1e-8)
        return tensor
class Preprocessing(object):
    def __init__(self):
        # Channel statistics used to train CLIP's image encoder.
        self.norm = Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711])

    def __call__(self, tensor):
        tensor = tensor / 255.0  # uint8 pixel values -> [0, 1]
        tensor = self.norm(tensor)
        return tensor
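# Minimal sanity-check sketch (illustrative, kept as comments so importing this
# module stays side-effect free): raw RGB frames in (T, 3, H, W) come out roughly
# zero-mean / unit-variance under the CLIP channel statistics.
#   frames = torch.randint(0, 256, (4, 3, 224, 224)).float()
#   frames = Preprocessing()(frames)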
class VideoLoader:
    """Pytorch video loader.
    Copied and modified from:
    https://github.com/linjieli222/HERO_Video_Feature_Extractor/blob/main/clip/video_loader.py
    """
    def __init__(
            self,
            framerate=1/2,
            size=224,
            centercrop=True,
    ):
        self.centercrop = centercrop
        self.size = size
        self.framerate = framerate
    def _get_video_info(self, video_path):
        probe = ffmpeg.probe(video_path)
        video_stream = next((stream for stream in probe['streams']
                             if stream['codec_type'] == 'video'), None)
        width = int(video_stream['width'])
        height = int(video_stream['height'])
        fps = math.floor(convert_to_float(video_stream['avg_frame_rate']))
        try:
            frames_length = int(video_stream['nb_frames'])
            duration = float(video_stream['duration'])
        except Exception:
            # Some containers (e.g. live streams) omit these fields; mark them unknown.
            frames_length, duration = -1, -1
        info = {"duration": duration, "frames_length": frames_length,
                "fps": fps, "height": height, "width": width}
        return info
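    # Illustrative return value (numbers are placeholders, not from a real probe):
    #   {"duration": 12.4, "frames_length": 372, "fps": 30, "height": 720, "width": 1280}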
    def _get_output_dim(self, h, w):
        if isinstance(self.size, tuple) and len(self.size) == 2:
            return self.size
        elif h >= w:
            # Scale the shorter side to self.size, preserving the aspect ratio.
            return int(h * self.size / w), self.size
        else:
            return self.size, int(w * self.size / h)
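    # Example: with size=224 and a 720x1280 (h x w) landscape video, the short
    # side is scaled to 224 and the long side keeps the aspect ratio:
    #   self._get_output_dim(720, 1280) -> (224, 398)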
    def read_video_from_file(self, video_path):
        try:
            info = self._get_video_info(video_path)
            h, w = info["height"], info["width"]
        except Exception:
            raise RuntimeError('ffprobe failed at: {}'.format(video_path))
        height, width = self._get_output_dim(h, w)
        try:
            duration = info["duration"]
            fps = self.framerate
            if duration > 0 and duration < 1/fps + 0.1:
                # Very short clip: raise the sampling rate so at least ~2 frames survive.
                fps = 2 / max(int(duration), 1)
                print("short video {}: duration={}, fps={}".format(video_path, duration, fps))
        except Exception:
            fps = self.framerate
        cmd = (
            ffmpeg
            .input(video_path)
            .filter('fps', fps=fps)
            .filter('scale', width, height)
        )
        if self.centercrop:
            x = int((width - self.size) / 2.0)
            y = int((height - self.size) / 2.0)
            cmd = cmd.crop(x, y, self.size, self.size)
        out, _ = (
            cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24')
            .run(capture_stdout=True, quiet=True)
        )
        if self.centercrop and isinstance(self.size, int):
            height, width = self.size, self.size
        # Raw RGB bytes -> (T, H, W, 3) uint8 -> float32 (T, 3, H, W) for CLIP.
        video = np.frombuffer(out, np.uint8).reshape(
            [-1, height, width, 3])
        video = torch.from_numpy(video.astype('float32'))
        video = video.permute(0, 3, 1, 2)
        return video
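

if __name__ == "__main__":
    # Runnable end-to-end sketch (illustrative): "example.mp4" and the query are
    # placeholders; a CUDA device and the CLIP weights loadable through
    # run_on_video.clip are assumed to be available.
    extractor = ClipFeatureExtractor(framerate=1/2, size=224, device="cuda")
    video_feats = extractor.encode_video("example.mp4")
    text_feats = extractor.encode_text(["a person is cooking"])
    print(video_feats.shape, [tuple(t.shape) for t in text_feats])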