import functools

from transformers import ViTFeatureExtractor
import torchvision.transforms.functional as fn
import torch as th


@functools.lru_cache(maxsize=4)
def _load_feature_extractor(feature_extractor_name):
    """Load and cache a ``ViTFeatureExtractor`` for ``feature_extractor_name``.

    ``from_pretrained`` reads configuration from disk (or the network), so
    doing it once per name instead of once per video avoids repeated work.
    """
    return ViTFeatureExtractor.from_pretrained(feature_extractor_name)


def video2image(video, feature_extractor_name, *, grid_size=7, image_size=224):
    """Summarise a video as one image built from a grid of sampled frames.

    ``grid_size ** 2`` frames are sampled at evenly spaced indices. Each one
    is preprocessed by the ViT feature extractor, and the results are tiled
    row-major into a ``grid_size`` x ``grid_size`` mosaic (cell ``(r, c)``
    holds sampled frame ``r * grid_size + c``). The mosaic is then resized so
    its shorter side is ``image_size``.

    Args:
        video: 4-D tensor whose last axis is moved to the front before
            sampling. Presumably ``(T, H, W, C)`` channel-last frames;
            TODO confirm against callers.
        feature_extractor_name: name or path passed to
            ``ViTFeatureExtractor.from_pretrained``.
        grid_size: number of tiles per side (default 7, i.e. 49 frames).
        image_size: target size of the shorter side of the output.

    Returns:
        Tensor of shape ``(C, H', W')`` with ``min(H', W') == image_size``.
        It is square when the extractor emits square frames.

    Raises:
        ValueError: if ``grid_size < 1`` or the video has no frames.
    """
    if grid_size < 1:
        raise ValueError("video2image: grid_size must be >= 1")
    num_cells = grid_size * grid_size

    feature_extractor = _load_feature_extractor(feature_extractor_name)

    # Move the last axis to the front: (T, H, W, C) -> (C, T, H, W) if
    # the input is channel-last as assumed above.
    vid = th.permute(video, (3, 0, 1, 2))
    num_frames = vid.shape[1]
    if num_frames == 0:
        # Without this guard linspace(0, -1, ...) yields index -1 and the
        # failure surfaces later as an opaque IndexError.
        raise ValueError("video2image: video contains no frames")

    # Evenly spaced frame indices; frames repeat when the clip is shorter
    # than the number of grid cells.
    samp = th.linspace(0, num_frames - 1, num_cells, dtype=th.long)
    frames = list(vid[:, samp, :, :].unbind(dim=1))  # each frame: (C, H, W)

    pixel_values = feature_extractor(frames, return_tensors="pt")["pixel_values"]
    _, channels, height, width = pixel_values.shape

    # (rows*cols, C, h, w) -> (rows, cols, C, h, w) -> (C, rows, h, cols, w)
    # -> (C, rows*h, cols*w): frames tiled left-to-right, top-to-bottom.
    grid = (
        pixel_values.reshape(grid_size, grid_size, channels, height, width)
        .permute(2, 0, 3, 1, 4)
        .reshape(channels, grid_size * height, grid_size * width)
    )
    return fn.resize(grid, size=[image_size])