# File-viewer header captured along with the source (not part of the program):
# author: chiyoi | commit: 804b63c ("working") | size: 2.55 kB | scan: no virus
from pathlib import Path
import random
from typing import Literal
import cv2
import numpy as np
import tensorflow as tf
# Fraction of each class directory's entries that Data() assigns to the
# 'training' split (taken from the front of the directory listing).
TRAINING_RATIO = 0.1
# Fraction of each class directory's entries that Data() assigns to the
# 'validation' split (the slice immediately after the training slice).
VALIDATION_RATIO = 0.01
def format_frames(frame, output_size):
    """Convert one frame to float32 and fit it into ``output_size`` with padding.

    Args:
        frame: Image data accepted by ``tf.image.convert_image_dtype``.
        output_size: Two-element sequence unpacked as the positional
            target height and target width of ``tf.image.resize_with_pad``.

    Returns:
        The result of ``tf.image.resize_with_pad`` applied to the float32
        version of ``frame``.
    """
    as_float = tf.image.convert_image_dtype(frame, tf.float32)
    target_height, target_width = output_size
    return tf.image.resize_with_pad(as_float, target_height, target_width)
def frames_from_video_file(video_path: str, n_frames: int, output_size=(256, 256), frame_step=15):
    """Sample evenly spaced frames from a video file.

    The video is opened with OpenCV. When it holds at least
    ``1 + (n_frames - 1) * frame_step`` frames, reading starts at a random
    frame position; otherwise it starts at the beginning. For every sampled
    frame, ``frame_step`` frames are read and only the last one is kept.
    A failed read is replaced by an all-zero frame.

    Args:
        video_path: Path of the video file to open.
        n_frames: Sampling size parameter. The sampling loop runs
            ``n_frames - 1`` times, so that is the number of frames returned.
        output_size: Two-element size forwarded to ``format_frames`` and used
            for the height/width of zero padding frames.
        frame_step: Number of frames read for each frame that is kept.

    Returns:
        A NumPy array stacking the sampled frames along the first axis, with
        the last axis reordered by ``[2, 1, 0]`` (OpenCV decodes frames as
        BGR, so this reverses the channel order).

    Raises:
        ValueError: If the video file could not be opened.
    """
    # NOTE(review): only n_frames - 1 frames are produced; this is kept as-is
    # because the returned shape is part of the contract -- confirm callers
    # really expect n_frames - 1.
    capture = cv2.VideoCapture(video_path)
    try:
        if not capture.isOpened():
            raise ValueError('Video file could not be opened.')
        # OpenCV reports the frame count as a float; random.randint requires
        # integers (float arguments raise TypeError on Python 3.12+).
        total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        need_frames = 1 + (n_frames - 1) * frame_step
        if need_frames <= total_frames:
            # The loop below performs need_frames - 1 reads, so the largest
            # start that still stays inside the video is
            # total_frames - need_frames + 1 (randint is inclusive).
            start = random.randint(0, total_frames - need_frames + 1)
            capture.set(cv2.CAP_PROP_POS_FRAMES, start)
        frames = []
        for _ in range(n_frames - 1):
            # Skip ahead: only the last of the frame_step reads is kept.
            for _ in range(frame_step):
                ok, frame = capture.read()
            if ok:
                frames.append(format_frames(frame, output_size))
            else:
                # Pad with a float32 zero frame so the stacked array keeps
                # the same dtype as the float32 frames from format_frames.
                frames.append(
                    np.zeros((output_size[0], output_size[1], 3), dtype=np.float32)
                )
    finally:
        # Always free the decoder, even if reading or formatting fails.
        capture.release()
    frames = np.array(frames)
    frames = frames[..., [2, 1, 0]]
    return frames
def Data(data_dir: Path):
    """Split the entries of every class directory into training/validation.

    Each child of ``data_dir`` is treated as a class directory. Its entries
    are listed once, and the two splits are contiguous slices of that single
    listing: the first ``TRAINING_RATIO`` fraction goes to training and the
    following ``VALIDATION_RATIO`` fraction goes to validation. Entries
    beyond those two slices belong to neither split.

    Args:
        data_dir: Directory whose children are the class directories.

    Returns:
        A dict with the keys ``'training'`` and ``'validation'``; each maps
        a class directory name to the list of paths in that split.
    """
    training = {}
    validation = {}
    for class_dir in data_dir.iterdir():
        # List the directory once so both slices index the same ordering and
        # can never overlap (iterdir yields entries in arbitrary order).
        entries = list(class_dir.iterdir())
        training_end = int(len(entries) * TRAINING_RATIO)
        validation_end = int(len(entries) * (TRAINING_RATIO + VALIDATION_RATIO))
        training[class_dir.name] = entries[:training_end]
        validation[class_dir.name] = entries[training_end:validation_end]
    # NOTE(review): the listing is not sorted, so split membership follows
    # the file system's enumeration order -- confirm that is acceptable.
    return {
        'training': training,
        'validation': validation,
    }
def frame_generator(data_dir: Path, n_frames: int, split: Literal['training', 'validation']):
    """Build a generator function yielding ``(frames, label)`` pairs for a split.

    Labels are the positions of the class names in the sorted list of entry
    names found directly under ``data_dir``.

    Args:
        data_dir: Directory whose children are the class directories.
        n_frames: Forwarded to ``frames_from_video_file`` for every video.
        split: Which split of ``Data(data_dir)`` to draw paths from.

    Returns:
        A zero-argument function. Each call reshuffles the split's
        ``(path, class name)`` pairs and yields the frames of each video
        together with its integer class id.
    """
    label_of = {}
    for index, class_name in enumerate(sorted(entry.name for entry in data_dir.iterdir())):
        label_of[class_name] = index
    splits = Data(data_dir)

    def generator():
        samples = []
        for class_name, video_paths in splits[split].items():
            for video_path in video_paths:
                samples.append((video_path, class_name))
        # Shuffle on every invocation so each pass sees a new order.
        random.shuffle(samples)
        for video_path, class_name in samples:
            clip = frames_from_video_file(str(video_path), n_frames)
            yield clip, label_of[class_name]

    return generator
def total_steps(data_dir: Path):
    """Count the entries assigned to each split of ``data_dir``.

    Args:
        data_dir: Directory whose children are the class directories.

    Returns:
        A ``(training_count, validation_count)`` tuple with the total number
        of paths across all classes in each split of ``Data(data_dir)``.
    """
    splits = Data(data_dir)
    training_count = sum(len(paths) for paths in splits['training'].values())
    validation_count = sum(len(paths) for paths in splits['validation'].values())
    return training_count, validation_count