video2music / utilities /video_loader.py
kjysmu's picture
add files
4e46a55
import torch as th
from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
import ffmpeg
class VideoLoader(Dataset):
def __init__(
self,
fileList = [],
framerate=1,
size=112,
centercrop=False,
):
#self.csv = pd.read_csv(csv)
self.fileList = fileList
self.centercrop = centercrop
self.size = size
self.framerate = framerate
def __len__(self):
return len(self.fileList)
def _get_video_dim(self, video_path):
probe = ffmpeg.probe(video_path)
video_stream = next((stream for stream in probe['streams']
if stream['codec_type'] == 'video'), None)
width = int(video_stream['width'])
height = int(video_stream['height'])
return height, width
def _get_output_dim(self, h, w):
if isinstance(self.size, tuple) and len(self.size) == 2:
return self.size
elif h >= w:
return int(h * self.size / w), self.size
else:
return self.size, int(w * self.size / h)
def __getitem__(self, idx):
video_path = self.fileList[idx]
output_file = video_path[:video_path.rfind(".")] + ".npy"
#video_path = self.csv['video_path'].values[idx]
#output_file = self.csv['feature_path'].values[idx]
if not(os.path.isfile(output_file)) and os.path.isfile(video_path):
print('Decoding video: {}'.format(video_path))
try:
h, w = self._get_video_dim(video_path)
except:
print('ffprobe failed at: {}'.format(video_path))
return {'video': th.zeros(1), 'input': video_path,
'output': output_file}
height, width = self._get_output_dim(h, w)
cmd = (
ffmpeg
.input(video_path)
.filter('fps', fps=self.framerate)
.filter('scale', width, height)
)
if self.centercrop:
x = int((width - self.size) / 2.0)
y = int((height - self.size) / 2.0)
cmd = cmd.crop(x, y, self.size, self.size)
out, _ = (
cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24')
.run(capture_stdout=True, quiet=True)
)
if self.centercrop and isinstance(self.size, int):
height, width = self.size, self.size
video = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3])
video = th.from_numpy(video.astype('float32'))
video = video.permute(0, 3, 1, 2)
else:
video = th.zeros(1)
return {'video': video, 'input': video_path, 'output': output_file}