File size: 2,686 Bytes
8b79d57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import av
import os
import pims
import numpy as np
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_pil_image
from PIL import Image
class VideoReader(Dataset):
    """Dataset exposing the frames of a video file as PIL images.

    Frames are read through ``pims.PyAVVideoReader`` and optionally passed
    through a user-supplied transform.
    """

    def __init__(self, path, transform=None):
        """Open the video at ``path``.

        Args:
            path: Location handed to ``pims.PyAVVideoReader``.
            transform: Optional callable applied to every PIL frame.
        """
        self.video = pims.PyAVVideoReader(path)
        # Cache the rate reported by the reader at construction time.
        self.rate = self.video.frame_rate
        self.transform = transform

    @property
    def frame_rate(self):
        """Frame rate that the underlying reader reported when opened."""
        return self.rate

    def __len__(self):
        """Return the number of frames reported by the reader."""
        return len(self.video)

    def __getitem__(self, idx):
        """Return frame ``idx`` as a PIL image, transformed if requested."""
        image = Image.fromarray(np.asarray(self.video[idx]))
        if self.transform is None:
            return image
        return self.transform(image)
class VideoWriter:
    """Encode batches of frames into an H.264 video file through PyAV.

    The writer can be used as a context manager; leaving the ``with`` block
    calls :meth:`close`, which flushes the encoder and closes the file.
    """

    def __init__(self, path, frame_rate, bit_rate=1000000):
        """Open ``path`` for writing and configure the output stream.

        Args:
            path: Destination passed to ``av.open`` in write mode.
            frame_rate: Frame rate of the output; formatted to four decimals
                before being handed to ``add_stream``.
            bit_rate: Value assigned to the stream's ``bit_rate``.
        """
        self.container = av.open(path, mode='w')
        self.stream = self.container.add_stream('h264', rate=f'{frame_rate:.4f}')
        self.stream.pix_fmt = 'yuv420p'
        self.stream.bit_rate = bit_rate

    def __enter__(self):
        """Return the writer itself so it can be bound in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Finalize the file when the ``with`` block exits."""
        self.close()

    def write(self, frames):
        """Encode and mux a batch of frames.

        Args:
            frames: Tensor laid out as [T, C, H, W]. Values are expected in
                [0, 1]; anything outside that range is clamped. A single
                channel is repeated three times to form RGB.
        """
        # frames: [T, C, H, W]
        self.stream.width = frames.size(3)
        self.stream.height = frames.size(2)
        if frames.size(1) == 1:
            frames = frames.repeat(1, 3, 1, 1)  # convert grayscale to RGB
        # Clamp before scaling: without it, values slightly outside [0, 1]
        # wrap around when cast to uint8 and show up as speckle artifacts.
        frames = frames.clamp(0, 1).mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()
        for pixels in frames:
            frame = av.VideoFrame.from_ndarray(pixels, format='rgb24')
            self.container.mux(self.stream.encode(frame))

    def close(self):
        """Flush the encoder and close the container.

        The container is closed even if flushing raises, so the underlying
        file handle is never leaked; the exception still propagates.
        """
        try:
            # encode() with no frame drains whatever the encoder still holds.
            self.container.mux(self.stream.encode())
        finally:
            self.container.close()
class ImageSequenceReader(Dataset):
    """Dataset that serves the image files of a directory in sorted order."""

    def __init__(self, path, transform=None):
        """Index the image files found directly inside ``path``.

        Args:
            path: Directory containing the frames.
            transform: Optional callable applied to every loaded image.
        """
        self.path = path
        # Keep only regular, non-hidden files: sub-directories and entries
        # such as ``.DS_Store`` would otherwise be counted as frames and
        # fail when opened as images.
        self.files = sorted(
            name for name in os.listdir(path)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(path, name))
        )
        self.transform = transform

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return image ``idx``, transformed if a transform was supplied."""
        with Image.open(os.path.join(self.path, self.files[idx])) as img:
            # Read the pixel data while the file is still open so the image
            # stays usable after the ``with`` block exits.
            img.load()
        if self.transform is not None:
            return self.transform(img)
        return img
class ImageSequenceWriter:
    """Save frames as sequentially numbered image files in a directory."""

    def __init__(self, path, extension='jpg'):
        """Remember the output settings and create ``path`` if needed.

        Args:
            path: Output directory; created when it does not exist yet.
            extension: File extension, without the dot, used for every frame.
        """
        self.path = path
        self.extension = extension
        # Running index used to name the next file written.
        self.counter = 0
        os.makedirs(path, exist_ok=True)

    def write(self, frames):
        """Save each frame of the batch under the next zero-padded index."""
        # frames: [T, C, H, W]
        total = frames.shape[0]
        for position in range(total):
            filename = '.'.join((str(self.counter).zfill(4), self.extension))
            picture = to_pil_image(frames[position])
            picture.save(os.path.join(self.path, filename))
            self.counter += 1

    def close(self):
        """Do nothing; each file is written out by ``write`` itself."""
        pass
|