Spaces:
Sleeping
Sleeping
File size: 2,191 Bytes
139dd3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
from pathlib import Path
import random
from typing import Literal
import cv2
import numpy as np
import tensorflow as tf
# Fraction of each class's files that Data() assigns to the 'training' split.
training_ratio = 0.7
# Fraction of each class's files that Data() assigns to the 'validation'
# split; the slice starts where the training slice ends.
validation_ratio = 0.02
# Number of frames pick_frames() collects from every video.
num_frames = 8
# Number of frames pick_frames() reads from the capture per collected frame.
frame_step = 15
# Target size unpacked into tf.image.resize_with_pad by format_frame().
frame_size = (224, 224)
def format_frame(frame):
    """Convert one decoded frame to float32 and fit it to ``frame_size``.

    The frame is first converted with ``tf.image.convert_image_dtype`` and
    then passed through ``tf.image.resize_with_pad`` using the module-level
    ``frame_size`` as the target dimensions.

    Args:
        frame: Image data accepted by ``tf.image.convert_image_dtype``
            (presumably an HxWx3 uint8 array from OpenCV; confirm at caller).

    Returns:
        The resized-and-padded float32 tensor.
    """
    as_float = tf.image.convert_image_dtype(frame, tf.float32)
    return tf.image.resize_with_pad(as_float, *frame_size)
def pick_frames(video: str):
    """Sample ``num_frames`` frames, ``frame_step`` apart, from a video file.

    A random start position is chosen so that the whole sampling window
    (``1 + (num_frames - 1) * frame_step`` frames) fits inside the video.
    If the video is shorter than the window, sampling starts at the first
    frame and any frame that cannot be read is replaced by zeros.

    Args:
        video: Path of the video file. ``pathlib.Path`` objects are accepted
            as well and converted to ``str`` before being given to OpenCV.

    Returns:
        A NumPy array holding ``num_frames`` formatted frames with the last
        (channel) axis reversed, i.e. OpenCV's BGR order turned into RGB.

    Raises:
        ValueError: If OpenCV cannot open the file.
    """
    # Callers in this file pass pathlib.Path objects; OpenCV builds without
    # os.PathLike support only accept plain strings.
    capture = cv2.VideoCapture(str(video))
    try:
        if not capture.isOpened():
            raise ValueError('Video file could not be opened.')

        # OpenCV reports properties as floats; random.randint needs ints.
        total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        need_frames = 1 + (num_frames - 1) * frame_step
        if need_frames <= total_frames:
            # randint is inclusive on both ends, so the last valid start is
            # exactly total_frames - need_frames.
            start = random.randint(0, total_frames - need_frames)
            capture.set(cv2.CAP_PROP_POS_FRAMES, start)

        # float32 so a missing frame does not upcast the whole stack.
        blank = np.zeros(frame_size + (3,), dtype=np.float32)
        frames = []
        ok, frame = False, None
        for index in range(num_frames):
            # The first sample is the frame at the start position; every
            # later sample lies frame_step frames after the previous one,
            # which matches the window size computed above.
            for _ in range(1 if index == 0 else frame_step):
                ok, frame = capture.read()
            frames.append(format_frame(frame) if ok else blank)
    finally:
        # Release the decoder even when opening or formatting fails.
        capture.release()

    frames = np.array(frames)
    # Reverse the channel axis: BGR (OpenCV) -> RGB.
    return frames[..., [2, 1, 0]]
def Data(data_dir: str, train_fraction=None, validation_fraction=None):
    """Split the files of every class directory into training/validation.

    ``data_dir`` is expected to contain one sub-directory per class. The
    entries of each class directory are sorted by path and then sliced, so
    the two splits are reproducible and never overlap. Non-directory
    entries directly inside ``data_dir`` are ignored.

    Args:
        data_dir: Root directory holding one sub-directory per class.
        train_fraction: Fraction of each class used for training. Defaults
            to the module-level ``training_ratio`` when ``None``.
        validation_fraction: Fraction of each class used for validation,
            taken immediately after the training slice. Defaults to the
            module-level ``validation_ratio`` when ``None``.

    Returns:
        ``{'training': {class_name: [Path, ...]},
        'validation': {class_name: [Path, ...]}}``.
    """
    if train_fraction is None:
        train_fraction = training_ratio
    if validation_fraction is None:
        validation_fraction = validation_ratio

    splits = {'training': {}, 'validation': {}}
    for class_dir in sorted(Path(data_dir).iterdir()):
        if not class_dir.is_dir():
            # Stray files (README, .DS_Store, ...) are not classes.
            continue
        # iterdir() yields entries in arbitrary order; list once and sort so
        # both slices are cut from the same, stable sequence.
        files = sorted(class_dir.iterdir())
        train_end = int(len(files) * train_fraction)
        validation_end = int(len(files) * (train_fraction + validation_fraction))
        splits['training'][class_dir.name] = files[:train_end]
        splits['validation'][class_dir.name] = files[train_end:validation_end]
    return splits
def ClassMapping(data_dir: str):
    """Build the class-name <-> integer-id mapping for a dataset directory.

    Every sub-directory of ``data_dir`` is treated as one class. Names are
    sorted so ids are stable across runs; plain files in ``data_dir`` are
    skipped so they do not consume a class id.

    Args:
        data_dir: Root directory holding one sub-directory per class.

    Returns:
        A tuple ``(id_to_name, name_to_id)`` where ``id_to_name`` is the
        sorted list of class names and ``name_to_id`` maps each name to its
        index in that list.
    """
    id_to_name = sorted(
        entry.name for entry in Path(data_dir).iterdir() if entry.is_dir()
    )
    name_to_id = {name: index for index, name in enumerate(id_to_name)}
    return (id_to_name, name_to_id)
def FrameGenerator(data_dir: str, split: Literal['training', 'validation']):
    """Create a generator function yielding ``(frames, label)`` for a split.

    The class mapping and the file split are computed once, when this
    factory is called. Each call of the returned function collects the
    (video, class) pairs of ``split``, shuffles them and yields the frames
    picked from every video together with the integer id of its class.

    Args:
        data_dir: Root directory holding one sub-directory per class.
        split: Which part of the data to iterate over.

    Returns:
        A zero-argument generator function.
    """
    name_to_id = ClassMapping(data_dir)[1]
    data = Data(data_dir)

    def generator():
        samples = []
        for class_name, videos in data[split].items():
            for video in videos:
                samples.append((video, class_name))
        # Fresh order on every pass over the data.
        random.shuffle(samples)
        for video, class_name in samples:
            yield pick_frames(video), name_to_id[class_name]

    return generator
|