Spaces:
Sleeping
Sleeping
from pathlib import Path | |
import random | |
from typing import Literal | |
import cv2 | |
import numpy as np | |
import tensorflow as tf | |
# Fraction of each class directory's file listing that Data() assigns to the
# 'training' split (taken from the start of the listing).
TRAINING_RATIO: float = 0.1

# Fraction of each class directory's file listing that Data() assigns to the
# 'validation' split (the slice immediately following the training slice).
VALIDATION_RATIO: float = 0.01
def format_frames(frame, output_size):
    """Convert a frame to ``tf.float32`` and resize it with padding.

    Args:
        frame: Image passed to ``tf.image.convert_image_dtype``.
        output_size: Sequence unpacked as the positional arguments that follow
            the image in ``tf.image.resize_with_pad`` (callers in this file
            pass a two-element size).

    Returns:
        The result of ``tf.image.resize_with_pad`` applied to the float32
        version of ``frame``.
    """
    return tf.image.resize_with_pad(
        tf.image.convert_image_dtype(frame, tf.float32),
        *output_size,
    )
def frames_from_video_file(video_path: str, n_frames: int, output_size=(256, 256), frame_step=15):
    """Sample frames from a video at a fixed stride, starting at a random offset.

    When the video is long enough, playback is positioned at a random start
    frame. Then, ``n_frames - 1`` times, ``frame_step`` frames are read and the
    last one read is kept (formatted with ``format_frames``). A failed read is
    replaced by an all-zero frame.

    NOTE(review): the sampling loop runs ``n_frames - 1`` times, so the result
    holds ``n_frames - 1`` frames even though ``need_frames`` budgets for
    ``n_frames``. This is preserved as-is; confirm whether an initial frame
    read was intended.

    Args:
        video_path: Path of the video file, as a string.
        n_frames: Frame budget; see the note above about the actual count.
        output_size: Size forwarded to ``format_frames``; its first two items
            are also used as the shape of zero padding frames.
        frame_step: Number of frames read per kept frame; must be >= 1.

    Returns:
        ``np.ndarray`` of the kept frames stacked along axis 0, with the last
        axis reordered by ``[2, 1, 0]`` (OpenCV decodes frames as BGR, so this
        yields RGB).

    Raises:
        ValueError: If ``frame_step`` is less than 1 or the video cannot be
            opened.
    """
    if frame_step < 1:
        # With no inner reads the success flag would never be assigned.
        raise ValueError('frame_step must be at least 1.')

    capture = cv2.VideoCapture(video_path)
    try:
        if not capture.isOpened():
            raise ValueError('Video file could not be opened.')

        total_frames = capture.get(cv2.CAP_PROP_FRAME_COUNT)
        need_frames = 1 + (n_frames - 1) * frame_step
        if need_frames <= total_frames:
            # OpenCV reports properties as floats, while random.randint
            # requires integers (a float raises TypeError on Python 3.12+).
            start = random.randint(0, int(total_frames) - need_frames + 1)
            capture.set(cv2.CAP_PROP_POS_FRAMES, start)

        frames = []
        for _ in range(n_frames - 1):
            # Advance frame_step frames; only the last one read is kept.
            for _ in range(frame_step):
                ok, frame = capture.read()
            if ok:
                frames.append(format_frames(frame, output_size))
            else:
                # Pad with float32 so one failed read does not upcast the
                # whole clip to float64 when the frames are stacked.
                frames.append(
                    np.zeros((output_size[0], output_size[1], 3), dtype=np.float32)
                )
    finally:
        # Release the capture on every path, including errors.
        capture.release()

    frames = np.array(frames)
    # Reverse the channel order on the last axis (BGR -> RGB).
    return frames[..., [2, 1, 0]]
def Data(data_dir: Path):
    """Split the entries of every class directory into training and validation.

    Each child of ``data_dir`` is treated as a class directory. Its entries
    are listed once and sorted, so both splits are cut from the same ordering:
    the split is reproducible and the two slices cannot overlap
    (``Path.iterdir`` yields entries in arbitrary order).

    The first ``int(len * TRAINING_RATIO)`` entries go to ``'training'`` and
    the following entries up to
    ``int(len * (TRAINING_RATIO + VALIDATION_RATIO))`` go to ``'validation'``.
    Entries beyond that index are not part of either split.

    Args:
        data_dir: Directory whose children are class directories.

    Returns:
        ``{'training': {class_name: [Path, ...]}, 'validation': {...}}`` with
        one key per child of ``data_dir`` in each split.
    """
    splits = {'training': {}, 'validation': {}}
    for class_dir in sorted(data_dir.iterdir()):
        # Single sorted listing shared by both slices.
        paths = sorted(class_dir.iterdir())
        train_end = int(len(paths) * TRAINING_RATIO)
        validation_end = int(len(paths) * (TRAINING_RATIO + VALIDATION_RATIO))
        splits['training'][class_dir.name] = paths[:train_end]
        splits['validation'][class_dir.name] = paths[train_end:validation_end]
    return splits
def frame_generator(data_dir: Path, n_frames: int, split: Literal['training', 'validation']):
    """Build a generator function yielding ``(frames, label)`` for one split.

    Class labels are the positions of the child names of ``data_dir`` in
    sorted order. The split listing is computed once, when this function is
    called; every call of the returned function reshuffles the samples.

    Args:
        data_dir: Directory whose children are class directories.
        n_frames: Forwarded to ``frames_from_video_file``.
        split: Key of the split to iterate, ``'training'`` or ``'validation'``.

    Returns:
        A zero-argument generator function. Each item it yields is a tuple of
        the frames loaded by ``frames_from_video_file`` and the integer class
        id of the video's class directory.
    """
    sorted_names = sorted(entry.name for entry in data_dir.iterdir())
    label_by_class = dict(zip(sorted_names, range(len(sorted_names))))
    data = Data(data_dir)

    def generator():
        # Flatten {class_name: [paths]} into (path, class_name) samples.
        samples = []
        for class_name, video_paths in data[split].items():
            samples.extend((video_path, class_name) for video_path in video_paths)

        random.shuffle(samples)

        for video_path, class_name in samples:
            clip = frames_from_video_file(str(video_path), n_frames)
            yield clip, label_by_class[class_name]

    return generator
def total_steps(data_dir: Path):
    """Count the entries assigned to each split by ``Data``.

    Args:
        data_dir: Directory whose children are class directories.

    Returns:
        Tuple ``(training_count, validation_count)``, each the total number
        of paths across all classes in that split.
    """
    splits = Data(data_dir)
    training_count = sum(len(paths) for paths in splits['training'].values())
    validation_count = sum(len(paths) for paths in splits['validation'].values())
    return training_count, validation_count