|
|
|
|
|
import unittest |
|
|
|
import pytest |
|
from pytorchvideo.data.frame_video import FrameVideo |
|
from utils import temp_frame_video |
|
|
|
|
|
class TestFrameVideo(unittest.TestCase):
    """Tests for ``FrameVideo`` constructed via ``FrameVideo.from_frame_paths``."""

    def test_frame_video_works(self):
        """Check ``duration`` and the frames/indices returned by ``get_clip``."""
        frame_names = [f"{i}.png" for i in range(3)]
        with temp_frame_video(frame_names) as (f_name, data):
            frame_paths = [f_name / x for x in frame_names]
            test_video = FrameVideo.from_frame_paths(frame_paths)

            # Three frames are expected to span 0.1 s (presumably the default
            # 30 fps — TODO confirm against FrameVideo's default). Duration is
            # a computed float, so compare with a tolerance, not exact equality.
            expected_duration = 0.1
            self.assertAlmostEqual(test_video.duration, expected_duration)

            # Full clip [0, 0.1]: every frame, in order. The slices below
            # assume the second axis of `data` indexes frames.
            clip = test_video.get_clip(0, 0.1)
            frames, indices = clip["video"], clip["frame_indices"]
            self.assertTrue(frames.equal(data))
            self.assertEqual(indices, [0, 1, 2])

            # Full clip with a frame-index filter that keeps only the middle frame.
            clip = test_video.get_clip(0, 0.1, lambda lst: lst[1:2])
            frames, indices = clip["video"], clip["frame_indices"]
            self.assertTrue(frames.equal(data[:, 1:2]))
            self.assertEqual(indices, [1])

            # Shorter clip [0, 0.066]: only the first two frames fall inside it.
            clip = test_video.get_clip(0, 0.066)
            frames, indices = clip["video"], clip["frame_indices"]
            self.assertTrue(frames.equal(data[:, :2]))
            self.assertEqual(indices, [0, 1])

            # A time range entirely past the end of the video yields no clip.
            result = test_video.get_clip(3, 5)
            self.assertIsNone(result)

    def test_open_video_failure(self):
        """Reading a clip whose frame path does not exist must raise."""
        # Construction is expected to succeed; the failure is expected to
        # surface only when get_clip tries to load the missing frame.
        test_video = FrameVideo.from_frame_paths(["non_existent_file.txt"])
        with pytest.raises(Exception):
            test_video.get_clip(0, 0.01)

    def test_empty_frames_failure(self):
        """An empty list of frame paths is rejected with ``AssertionError``."""
        with pytest.raises(AssertionError):
            FrameVideo.from_frame_paths([])
|
|