# chiyoi's picture
# update
# 139dd3e
# raw history blame
# No virus
# 2.19 kB
from pathlib import Path
import random
from typing import Literal
import cv2
import numpy as np
import tensorflow as tf
# Fraction of each class directory's files (from the start of the listing)
# assigned to the training split.
training_ratio: float = 0.7
# Fraction of files, directly after the training slice, assigned to the
# validation split; whatever remains is not assigned to either split here.
validation_ratio: float = 0.02
# Number of frames sampled from each video clip.
num_frames: int = 8
# Distance, in frames, between two consecutive sampled frames.
frame_step: int = 15
# Output (height, width) that every frame is padded/resized to.
frame_size: tuple = (224, 224)
def format_frame(frame):
    """Turn one decoded video frame into a float32 tensor of size ``frame_size``.

    The frame is first converted with ``tf.image.convert_image_dtype`` (for
    integer input such as OpenCV's uint8 frames this rescales pixels into the
    unit range), then letterboxed to the module-level ``frame_size``
    (height, width) with ``tf.image.resize_with_pad``, which keeps the aspect
    ratio and pads the remainder.
    """
    target_height, target_width = frame_size
    as_float = tf.image.convert_image_dtype(frame, tf.float32)
    return tf.image.resize_with_pad(as_float, target_height, target_width)
def pick_frames(video: str):
    """Sample ``num_frames`` frames, ``frame_step`` frames apart, from a video.

    When the video is long enough to hold the whole sampled span, a random
    start frame is chosen so that every sample lies inside the video; shorter
    videos are read from their first frame. Any sample that cannot be read
    (e.g. past the end of the stream) is replaced by an all-zero frame.

    Args:
        video: Path of the video file (``str`` or a ``pathlib.Path``).

    Returns:
        A float32 NumPy array of shape ``(num_frames, *frame_size, 3)``, with
        channels reordered from OpenCV's BGR to RGB.

    Raises:
        ValueError: If OpenCV cannot open the video file.
    """
    # OpenCV's Python bindings expect a string filename; callers in this file
    # pass pathlib.Path objects.
    capture = cv2.VideoCapture(str(video))
    try:
        if not capture.isOpened():
            raise ValueError('Video file could not be opened.')
        # CAP_PROP_FRAME_COUNT comes back as a float; random.randint needs ints
        # (non-integer arguments raise TypeError on Python 3.12+).
        total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        # Frames spanned from the first sample to the last one, inclusive.
        need_frames = 1 + (num_frames - 1) * frame_step
        if need_frames <= total_frames:
            # randint is inclusive at both ends: the last start that still fits
            # the whole span is total_frames - need_frames.
            start = random.randint(0, total_frames - need_frames)
            capture.set(cv2.CAP_PROP_POS_FRAMES, start)
        # float32 so padded clips have the same dtype as fully decoded ones.
        blank = np.zeros(frame_size + (3,), dtype=np.float32)
        frames = []
        for index in range(num_frames):
            if index:
                # Skip the frames between two samples so samples land exactly
                # frame_step apart, matching need_frames above. grab() advances
                # the stream without retrieving the image.
                for _ in range(frame_step - 1):
                    capture.grab()
            ok, frame = capture.read()
            frames.append(format_frame(frame) if ok else blank)
    finally:
        # Release the capture even if decoding/formatting raises.
        capture.release()
    frames = np.array(frames)
    # OpenCV decodes frames as BGR; reorder the channel axis to RGB.
    return frames[..., [2, 1, 0]]
def Data(data_dir: str, *, train_ratio=None, val_ratio=None):
    """Split every class directory's files into training and validation lists.

    ``data_dir`` is expected to hold one sub-directory per class; each entry
    inside a class directory is one sample. Listings are sorted so the split
    is reproducible across runs and filesystems, and each class directory is
    listed exactly once so the two splits are guaranteed to be disjoint.
    Non-directory entries directly under ``data_dir`` are skipped.

    Args:
        data_dir: Root directory of the dataset.
        train_ratio: Fraction of each class's sorted samples (taken from the
            start) used for training. Defaults to module ``training_ratio``.
        val_ratio: Fraction of each class's samples, directly after the
            training slice, used for validation. Defaults to module
            ``validation_ratio``.

    Returns:
        ``{'training': {class_name: [Path, ...]},
        'validation': {class_name: [Path, ...]}}``. Samples past
        ``train_ratio + val_ratio`` belong to neither split.
    """
    if train_ratio is None:
        train_ratio = training_ratio
    if val_ratio is None:
        val_ratio = validation_ratio

    splits = {'training': {}, 'validation': {}}
    for class_dir in sorted(Path(data_dir).iterdir()):
        # Stray files (README, .DS_Store, ...) would make iterdir() below raise.
        if not class_dir.is_dir():
            continue
        # iterdir() order is arbitrary; sort for a deterministic split and
        # list only once so training and validation come from the same order.
        samples = sorted(class_dir.iterdir())
        train_end = int(len(samples) * train_ratio)
        val_end = int(len(samples) * (train_ratio + val_ratio))
        splits['training'][class_dir.name] = samples[:train_end]
        splits['validation'][class_dir.name] = samples[train_end:val_end]
    return splits
def ClassMapping(data_dir: str):
    """Assign a contiguous integer id to every class name under ``data_dir``.

    Each entry directly inside ``data_dir`` is taken as a class, and names are
    sorted so the ids depend only on the set of names, not on listing order.

    Returns:
        ``(id_to_name, name_to_id)``: the sorted list of names (position is the
        id) and the inverse ``{name: id}`` dictionary.
    """
    id_to_name = sorted(entry.name for entry in Path(data_dir).iterdir())
    name_to_id = dict(zip(id_to_name, range(len(id_to_name))))
    return id_to_name, name_to_id
def FrameGenerator(data_dir: str, split: Literal['training', 'validation']):
    """Return a zero-argument generator function over one dataset split.

    The class-id mapping and the train/validation file split are computed once,
    when this factory is called. Every call of the returned function builds the
    ``(video, class_name)`` list for ``split``, shuffles it with ``random``, and
    yields ``(frames, label)`` pairs, where ``frames`` comes from
    ``pick_frames`` and ``label`` is the integer id from ``ClassMapping``.
    (Presumably handed to ``tf.data.Dataset.from_generator`` — confirm with
    the caller.)
    """
    _, name_to_id = ClassMapping(data_dir)
    data = Data(data_dir)

    def generator():
        samples = []
        for class_name, videos in data[split].items():
            samples.extend((video, class_name) for video in videos)
        # Fresh shuffle on each call, i.e. once per pass over the split.
        random.shuffle(samples)
        for video, class_name in samples:
            yield pick_frames(video), name_to_id[class_name]

    return generator