# search-in-video / utils.py
# Armen Gabrielyan
# add batch generation
# cde7ed6
import functools

import torch as th
import torchvision.transforms.functional as fn
from transformers import ViTFeatureExtractor
@functools.lru_cache(maxsize=8)
def _load_feature_extractor(feature_extractor_name):
    """Return a ViTFeatureExtractor for ``feature_extractor_name``, loaded once per name.

    ``from_pretrained`` may touch disk or the network, so repeated
    ``video2image`` calls reuse the already-loaded extractor instead of
    reloading it every time.
    """
    return ViTFeatureExtractor.from_pretrained(feature_extractor_name)


def video2image(video, feature_extractor_name, *, grid_size=7, output_size=224):
    """Tile evenly spaced frames of ``video`` into a single grid image.

    ``grid_size * grid_size`` frame indices are spread evenly (via
    ``torch.linspace`` with an integer dtype) across the whole clip. Each
    selected frame is preprocessed by the named ViT feature extractor, and the
    processed frames are laid out row by row into a ``grid_size x grid_size``
    mosaic. The mosaic is then resized so that its smaller edge equals
    ``output_size``.

    Args:
        video: 4-D tensor. Its last axis is moved to the front and frames are
            sampled along its first axis, so it is presumably
            ``(frames, height, width, channels)`` -- TODO confirm with callers.
            Clips shorter than ``grid_size ** 2`` frames are accepted; some
            frames are then repeated in the grid.
        feature_extractor_name: Name or path passed to
            ``ViTFeatureExtractor.from_pretrained``. Loaded extractors are
            cached per name.
        grid_size: Tiles per row and per column. The default of 7 (49 frames)
            reproduces the original fixed layout.
        output_size: Target length of the smaller edge of the returned image.

    Returns:
        A 3-D tensor holding the resized mosaic (presumably
        ``(channels, H, W)``, as produced by the feature extractor).

    Raises:
        ValueError: If ``grid_size`` is less than 1 or ``video`` has no frames.
    """
    if grid_size < 1:
        raise ValueError(f"grid_size must be >= 1, got {grid_size}")
    if video.shape[0] == 0:
        raise ValueError("video has no frames to sample")

    feature_extractor = _load_feature_extractor(feature_extractor_name)

    # Move the last axis to the front; frames are then indexed along dim 1.
    vid = th.permute(video, (3, 0, 1, 2))
    num_tiles = grid_size * grid_size
    # Evenly spaced integer frame indices over the whole clip. Neighbouring
    # indices collapse to the same frame when the clip is shorter than
    # num_tiles.
    samp = th.linspace(0, vid.shape[1] - 1, num_tiles, dtype=th.long)
    frames = list(vid[:, samp, :, :].unbind(1))

    pixel_values = feature_extractor(frames, return_tensors="pt")["pixel_values"]

    # Each row joins grid_size consecutive frames side by side (width, dim 2);
    # the rows are then stacked top to bottom (height, dim 1).
    rows = [
        th.cat(list(pixel_values[r * grid_size:(r + 1) * grid_size]), dim=2)
        for r in range(grid_size)
    ]
    # A single-element size resizes the smaller edge to output_size.
    return fn.resize(th.cat(rows, dim=1), size=[output_size])