"""Sample frames from video files and build training/validation generators.

The dataset directory is expected to contain one sub-directory per class,
each holding the video files that belong to that class.
"""

import random
from pathlib import Path
from typing import Literal

import cv2
import numpy as np
import tensorflow as tf

# Fraction of each class's (name-sorted) files used for training, and the
# fraction immediately following it that is used for validation.
TRAINING_RATIO = 0.1
VALIDATION_RATIO = 0.01


def format_frames(frame, output_size):
    """Convert a frame to float32 and resize it with padding.

    Args:
        frame: Image passed to ``tf.image.convert_image_dtype``.
        output_size: ``(height, width)`` pair forwarded to
            ``tf.image.resize_with_pad``.

    Returns:
        The converted, resized-and-padded frame.
    """
    frame = tf.image.convert_image_dtype(frame, tf.float32)
    return tf.image.resize_with_pad(frame, *output_size)


def frames_from_video_file(video_path: str, n_frames: int, output_size=(256, 256), frame_step=15):
    """Read ``n_frames`` frames, ``frame_step`` frames apart, from a video.

    If the video is long enough to hold the whole sampling window, the window
    starts at a random position; otherwise reading starts at the current
    (initial) position and frames that cannot be read are replaced by zeros.

    Args:
        video_path: Path of the video file to open.
        n_frames: Number of frames to return (at least 1).
        output_size: ``(height, width)`` of each returned frame.
        frame_step: Distance, in frames, between two sampled frames
            (at least 1).

    Returns:
        A float32 array of ``n_frames`` frames with the channel axis reversed
        (index ``[2, 1, 0]``), i.e. OpenCV's BGR order turned into RGB.

    Raises:
        ValueError: If ``n_frames`` or ``frame_step`` is smaller than 1, or if
            the video file cannot be opened.
    """
    if n_frames < 1:
        raise ValueError('n_frames must be at least 1.')
    if frame_step < 1:
        raise ValueError('frame_step must be at least 1.')

    capture = cv2.VideoCapture(video_path)
    try:
        if not capture.isOpened():
            raise ValueError('Video file could not be opened.')

        # OpenCV reports properties as floats; random.randint needs ints.
        total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        need_frames = 1 + (n_frames - 1) * frame_step
        if need_frames <= total_frames:
            # randint is inclusive on both ends, so the last valid start is
            # exactly total_frames - need_frames.
            start = random.randint(0, total_frames - need_frames)
            capture.set(cv2.CAP_PROP_POS_FRAMES, start)

        # Same dtype as format_frames() output so the stacked clip stays
        # float32 even when some reads fail.
        blank = np.zeros((output_size[0], output_size[1], 3), dtype=np.float32)
        frames = []

        # The frame at the start position is the first sample.
        ok, frame = capture.read()
        frames.append(format_frames(frame, output_size) if ok else blank)

        for _ in range(n_frames - 1):
            # Advance frame_step frames; only the last one read is kept.
            for _ in range(frame_step):
                ok, frame = capture.read()
            frames.append(format_frames(frame, output_size) if ok else blank)
    finally:
        # Release the capture even if reading or formatting raised.
        capture.release()

    # Reverse the channel axis (BGR -> RGB).
    return np.array(frames)[..., [2, 1, 0]]


def Data(data_dir: Path):
    """Split the files of every class directory into training and validation.

    Class directories and their files are processed in sorted order so the
    split is reproducible and the two subsets never overlap. Entries of
    ``data_dir`` that are not directories are ignored.

    Args:
        data_dir: Directory containing one sub-directory per class.

    Returns:
        ``{'training': {class_name: [paths]}, 'validation': {class_name: [paths]}}``
        where training takes the first ``TRAINING_RATIO`` share of each class
        and validation the following ``VALIDATION_RATIO`` share.
    """
    training = {}
    validation = {}
    class_dirs = sorted(
        (entry for entry in data_dir.iterdir() if entry.is_dir()),
        key=lambda entry: entry.name,
    )
    for class_dir in class_dirs:
        # List each class once so both subsets are cut from the same ordering.
        paths = sorted(class_dir.iterdir())
        training_end = int(len(paths) * TRAINING_RATIO)
        validation_end = int(len(paths) * (TRAINING_RATIO + VALIDATION_RATIO))
        training[class_dir.name] = paths[:training_end]
        validation[class_dir.name] = paths[training_end:validation_end]
    return {'training': training, 'validation': validation}


def frame_generator(data_dir: Path, n_frames: int, split: Literal['training', 'validation']):
    """Build a generator function yielding ``(frames, label)`` pairs.

    Args:
        data_dir: Directory containing one sub-directory per class.
        n_frames: Number of frames to sample from every video.
        split: Which subset of ``Data(data_dir)`` to iterate.

    Returns:
        A zero-argument callable. Each call returns a fresh generator that
        visits the videos of the chosen split in a newly shuffled order and
        yields the sampled frames together with the integer class id
        (the index of the class name in sorted order).

    Raises:
        KeyError: If ``split`` is not a key of ``Data(data_dir)``.
    """
    # Look the split up eagerly so a bad value fails here, not mid-iteration.
    paths_by_class = Data(data_dir)[split]
    class_ids_for_name = {
        name: class_id for class_id, name in enumerate(sorted(paths_by_class))
    }

    def generator():
        pairs = [
            (path, name)
            for name, paths in paths_by_class.items()
            for path in paths
        ]
        random.shuffle(pairs)
        for path, name in pairs:
            yield frames_from_video_file(str(path), n_frames), class_ids_for_name[name]

    return generator


def total_steps(data_dir: Path):
    """Count the videos in the training and validation splits.

    Args:
        data_dir: Directory containing one sub-directory per class.

    Returns:
        A ``(training_count, validation_count)`` tuple.
    """
    data = Data(data_dir)
    training_count = sum(len(paths) for paths in data['training'].values())
    validation_count = sum(len(paths) for paths in data['validation'].values())
    return training_count, validation_count