Spaces:
Sleeping
Sleeping
from pathlib import Path | |
import random | |
from typing import Literal | |
import cv2 | |
import numpy as np | |
import tensorflow as tf | |
from configurations import * | |
def format_frame(frame):
    """Return ``frame`` as a float32 image letterboxed to ``frame_size``.

    ``tf.image.convert_image_dtype`` performs the float32 conversion (for
    integer inputs it also rescales the values). ``tf.image.resize_with_pad``
    then resizes the image to ``frame_size`` while keeping its aspect ratio,
    padding the rest.

    NOTE(review): assumes ``frame`` is an OpenCV-decoded uint8 image and that
    ``frame_size`` is ``(height, width)`` -- the order ``resize_with_pad``
    takes its target arguments in.
    """
    as_float = tf.image.convert_image_dtype(frame, tf.float32)
    return tf.image.resize_with_pad(as_float, *frame_size)
def pick_frames(video: str):
    """Sample ``num_frames`` frames, spaced ``frame_step`` apart, from a video.

    The sampled frames are ``start, start + frame_step, ...``, a span of
    ``1 + (num_frames - 1) * frame_step`` frames. If the video is long enough,
    ``start`` is drawn uniformly so the whole span fits inside the video.
    Otherwise sampling starts at frame 0. Any frame that cannot be read is
    replaced by a zero-filled float32 frame.

    Args:
        video: Path of the video file (``str``, or anything ``str()``-able
            such as ``pathlib.Path``).

    Returns:
        ``np.ndarray`` of ``num_frames`` formatted frames, with the channel
        axis reordered from OpenCV's BGR to RGB.

    Raises:
        ValueError: if OpenCV cannot open ``video``.
    """
    capture = cv2.VideoCapture(str(video))
    if not capture.isOpened():
        capture.release()
        raise ValueError('Video file could not be opened.')
    try:
        # OpenCV reports the frame count as a float; random.randint needs ints
        # (floats are rejected outright since Python 3.12).
        total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        need_frames = 1 + (num_frames - 1) * frame_step
        if need_frames <= total_frames:
            # randint is inclusive at both ends, so the last valid start is
            # total_frames - need_frames.
            start = random.randint(0, total_frames - need_frames)
            capture.set(cv2.CAP_PROP_POS_FRAMES, start)
        # Same dtype as format_frame's output so the stacked array stays float32.
        blank = np.zeros((*frame_size, 3), dtype=np.float32)
        frames = []
        for index in range(num_frames):
            if index > 0:
                # Skip the frames between samples without decoding them.
                for _ in range(frame_step - 1):
                    capture.grab()
            ok, frame = capture.read()
            frames.append(format_frame(frame) if ok else blank)
    finally:
        capture.release()
    frames = np.array(frames)
    # OpenCV decodes frames in BGR order; reorder the channels to RGB.
    return frames[..., [2, 1, 0]]
def Data(*, root=None, train_ratio=None, val_ratio=None):
    """Split each class directory's videos into training/validation/testing.

    The dataset root holds one sub-directory per class. Non-directory entries
    at the root are ignored. Each class's entries are sorted by name, so the
    split is reproducible across machines, and cut into three consecutive
    slices:

    * the first ``train_ratio`` fraction goes to training,
    * the next ``val_ratio`` fraction goes to validation,
    * the rest goes to testing.

    Args:
        root: Dataset root directory; defaults to ``data_dir``.
        train_ratio: Fraction of each class used for training; defaults to
            ``training_ratio``.
        val_ratio: Fraction of each class used for validation; defaults to
            ``validation_ratio``.

    Returns:
        ``{'training' | 'validation' | 'testing': {class_name: [Path, ...]}}``.
    """
    if root is None:
        root = data_dir
    if train_ratio is None:
        train_ratio = training_ratio
    if val_ratio is None:
        val_ratio = validation_ratio

    def cut(count, ratio):
        # Round away float noise before truncating: 10 * (0.7 + 0.2) is
        # 8.999999999999998 and would otherwise truncate to 8 instead of 9.
        return int(round(count * ratio, 9))

    splits = {'training': {}, 'validation': {}, 'testing': {}}
    for class_dir in sorted(Path(root).iterdir()):
        # Stray files next to the class folders (e.g. .DS_Store) are not classes.
        if not class_dir.is_dir():
            continue
        # List once and sort: iterdir() order is filesystem-dependent.
        videos = sorted(class_dir.iterdir())
        train_end = cut(len(videos), train_ratio)
        val_end = cut(len(videos), train_ratio + val_ratio)
        splits['training'][class_dir.name] = videos[:train_end]
        splits['validation'][class_dir.name] = videos[train_end:val_end]
        splits['testing'][class_dir.name] = videos[val_end:]
    return splits
def FrameGenerator(split: Literal['training', 'validation', 'testing']):
    """Return a zero-argument generator function over one dataset split.

    Each call of the returned function yields ``(frames, label)`` pairs for
    every video in ``split``, in a freshly shuffled order. ``frames`` comes
    from ``pick_frames`` and ``label`` is ``name_to_id[class_name]``.
    Presumably consumed via ``tf.data.Dataset.from_generator`` -- TODO confirm
    against callers.

    Args:
        split: Key of the ``Data()`` mapping to iterate.

    Raises:
        KeyError: if ``split`` is not produced by ``Data()``. This is raised
            here, when the factory is built, rather than on first iteration.
    """
    videos_by_class = Data()[split]
    # The split's contents are fixed, so build the (path, class) pairs once;
    # each generator call shuffles its own copy.
    pairs = [
        (str(video), class_name)
        for class_name, videos in videos_by_class.items()
        for video in videos
    ]

    def generator():
        order = pairs.copy()
        random.shuffle(order)
        for video, class_name in order:
            yield pick_frames(video), name_to_id[class_name]

    return generator