|
from abc import ABC, abstractclassmethod, abstractmethod |
|
import glob |
|
import math |
|
import os |
|
from typing import Dict |
|
from typing_extensions import dataclass_transform |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
from tqdm.auto import tqdm |
|
|
|
|
|
def _bytes_feature(value):
    """Wrap a string / byte value in a tf.train.Feature holding a BytesList."""
    # Eager tensors must be unwrapped to their underlying bytes first.
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
|
|
|
|
|
def _float_feature(value):
    """Wrap a float / double value in a tf.train.Feature holding a FloatList."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
|
|
|
|
|
def _int64_feature(value):
    """Wrap a bool / enum / int / uint value in a tf.train.Feature holding an Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
|
|
|
|
|
def serialize_array(array):
    """Serialize a tensor / array-like into a scalar string tensor."""
    return tf.io.serialize_tensor(array)
|
|
|
|
|
class Dataset(ABC):
    """Write numpy arrays to (optionally sharded) TFRecord files and load them back.

    Subclasses supply the per-element feature schema via ``_get_features`` and
    the inverse deserialization via ``_parse_tfr_element``.
    """

    # Target maximum size of a single TFRecord shard, in bytes (100 MiB).
    TARGET_SHARD_SIZE = 100 * 1024 * 1024

    def __init__(self, dataset_path: str):
        # Path to a single .tfrecords file or a directory containing shards.
        self.dataset_path = dataset_path

    @classmethod
    def _parse_single_element(cls, element) -> tf.train.Example:
        """Wrap one element's features into a tf.train.Example."""
        features = tf.train.Features(feature=cls._get_features(element))
        return tf.train.Example(features=features)

    # NOTE: abstractclassmethod is deprecated since Python 3.3; the
    # equivalent modern spelling is stacking @classmethod over @abstractmethod.
    @classmethod
    @abstractmethod
    def _get_features(cls, element) -> Dict[str, tf.train.Feature]:
        """Return the feature dict (name -> tf.train.Feature) for one element."""

    @classmethod
    @abstractmethod
    def _parse_tfr_element(cls, element):
        """Deserialize one serialized tf.train.Example back into a tensor."""

    @classmethod
    def write_to_tfr(cls, data: np.ndarray, out_dir: str, filename: str):
        """Write ``data`` to TFRecord file(s) under ``out_dir``.

        Everything is first written to a single file; if that file exceeds
        ``TARGET_SHARD_SIZE`` it is deleted and rewritten as multiple shards.

        Args:
            data: Array whose first axis indexes elements.
            out_dir: Output directory (created if missing).
            filename: Base name for the generated ``.tfrecords`` files.
        """
        os.makedirs(out_dir, exist_ok=True)

        single_file_name = cls.__write_to_single_tfr(data, out_dir, filename)
        number_splits = cls.__get_number_splits(single_file_name)

        if number_splits > 1:
            # The single file is too large: replace it with N smaller shards.
            os.remove(single_file_name)
            cls.__write_to_multiple_tfr(data, out_dir, filename, number_splits)

    @classmethod
    def __write_to_multiple_tfr(
        cls, data: np.ndarray, out_dir: str, filename: str, n_splits: int
    ):
        """Write ``data`` as ``n_splits`` TFRecord shards.

        Returns:
            The total number of elements written across all shards.
        """
        file_count = 0

        # Elements per shard; the last shard may hold fewer.
        max_files = math.ceil(data.shape[0] / n_splits)

        print(f"Creating {n_splits} files with {max_files} elements each.")

        for i in tqdm(range(n_splits)):
            current_shard_name = os.path.join(
                out_dir,
                # Original hard-coded the base name; interpolate the
                # ``filename`` argument so it is actually used.
                f"{filename}.tfrecords-{str(i).zfill(len(str(n_splits)))}-of-{n_splits}",
            )
            # Context manager guarantees the shard is closed even on error.
            with tf.io.TFRecordWriter(current_shard_name) as writer:
                current_shard_count = 0
                while current_shard_count < max_files:
                    index = i * max_files + current_shard_count
                    # Last shard may run out of data before reaching max_files.
                    if index >= len(data):
                        break

                    current_element = data[index]
                    out = cls._parse_single_element(element=current_element)

                    writer.write(out.SerializeToString())
                    current_shard_count += 1
                    file_count += 1

        print(f"\nWrote {file_count} elements to TFRecord")
        return file_count

    @classmethod
    def __get_number_splits(cls, filename: str):
        """Return how many shards of at most TARGET_SHARD_SIZE the file needs."""
        single_file_size = os.path.getsize(filename)
        number_splits = math.ceil(single_file_size / cls.TARGET_SHARD_SIZE)
        return number_splits

    @classmethod
    def __write_to_single_tfr(cls, data: np.ndarray, out_dir: str, filename: str):
        """Write all of ``data`` into one TFRecord file and return its path."""
        current_path_name = os.path.join(
            out_dir,
            # Interpolate the base name (was a placeholder-free f-string).
            f"{filename}.tfrecords-0-of-1",
        )

        with tf.io.TFRecordWriter(current_path_name) as writer:
            for element in tqdm(data):
                writer.write(cls._parse_single_element(element).SerializeToString())

        return current_path_name

    def load(self) -> tf.data.TFRecordDataset:
        """Load the dataset from ``self.dataset_path`` and parse each element.

        Returns:
            A tf.data dataset with ``_parse_tfr_element`` mapped over it.

        Raises:
            ValueError: If the path is neither a file nor a directory.
        """
        path = self.dataset_path

        if os.path.isdir(path):
            dataset = self._load_folder(path)
        elif os.path.isfile(path):
            dataset = self._load_file(path)
        else:
            raise ValueError(f"Path {path} is not a valid file or folder.")

        return dataset.map(self._parse_tfr_element)

    def _load_file(self, path) -> tf.data.TFRecordDataset:
        """Load a single TFRecord file."""
        return tf.data.TFRecordDataset(path)

    def _load_folder(self, path) -> tf.data.TFRecordDataset:
        """Load every ``*.tfrecords*`` shard found recursively under ``path``."""
        return tf.data.TFRecordDataset(
            glob.glob(os.path.join(path, "**/*.tfrecords*"), recursive=True)
        )
|
|
|
|
|
class VideoDataset(Dataset):
    """TFRecord schema for uint8 video clips shaped (frames, height, width, depth)."""

    @classmethod
    def _get_features(cls, element) -> Dict[str, tf.train.Feature]:
        """Describe one clip: its four dimensions plus the serialized pixel data."""
        shape = element.shape
        return {
            "frames": _int64_feature(shape[0]),
            "height": _int64_feature(shape[1]),
            "width": _int64_feature(shape[2]),
            "depth": _int64_feature(shape[3]),
            "raw_video": _bytes_feature(serialize_array(element)),
        }

    @classmethod
    def _parse_tfr_element(cls, element):
        """Decode one serialized Example back into a uint8 video tensor."""
        schema = {
            "frames": tf.io.FixedLenFeature([], tf.int64),
            "height": tf.io.FixedLenFeature([], tf.int64),
            "width": tf.io.FixedLenFeature([], tf.int64),
            "raw_video": tf.io.FixedLenFeature([], tf.string),
            "depth": tf.io.FixedLenFeature([], tf.int64),
        }
        content = tf.io.parse_single_example(element, schema)

        # Restore the original 4-D shape from the stored dimension features.
        video = tf.io.parse_tensor(content["raw_video"], out_type=tf.uint8)
        target_shape = [
            content["frames"],
            content["height"],
            content["width"],
            content["depth"],
        ]
        return tf.reshape(video, shape=target_shape)
|
|
|
|
|
class ImageDataset(Dataset):
    """TFRecord schema for uint8 images shaped (height, width, depth)."""

    @classmethod
    def _get_features(cls, element) -> Dict[str, tf.train.Feature]:
        """Describe one image: its three dimensions plus the serialized pixel data."""
        shape = element.shape
        return {
            "height": _int64_feature(shape[0]),
            "width": _int64_feature(shape[1]),
            "depth": _int64_feature(shape[2]),
            "raw_image": _bytes_feature(serialize_array(element)),
        }

    @classmethod
    def _parse_tfr_element(cls, element):
        """Decode one serialized Example back into a uint8 image tensor."""
        schema = {
            "height": tf.io.FixedLenFeature([], tf.int64),
            "width": tf.io.FixedLenFeature([], tf.int64),
            "raw_image": tf.io.FixedLenFeature([], tf.string),
            "depth": tf.io.FixedLenFeature([], tf.int64),
        }
        content = tf.io.parse_single_example(element, schema)

        # Restore the original 3-D shape from the stored dimension features.
        image = tf.io.parse_tensor(content["raw_image"], out_type=tf.uint8)
        target_shape = [
            content["height"],
            content["width"],
            content["depth"],
        ]
        return tf.reshape(image, shape=target_shape)
|
|