Spaces:
Running
Running
| '''Module for loading the fakeavceleb dataset from tfrecord format''' | |
| import numpy as np | |
| import tensorflow as tf | |
| from data.augmentation_utils import create_frame_transforms, create_spec_transforms | |
# Schema handed to `tf.io.parse_single_example`: every feature is a scalar
# (shape `[]`) entry of the serialized `tf.train.Example`.
FEATURE_DESCRIPTION = {
    # Parsed by `_parse_function` but not returned from it.
    "video_path": tf.io.FixedLenFeature([], tf.string),
    # Raw frame bytes; `_parse_function` decodes them with `tf.io.decode_raw`.
    "image/encoded": tf.io.FixedLenFeature([], tf.string),
    # Integer class label; this is the label the pipeline actually returns.
    "clip/label/index": tf.io.FixedLenFeature([], tf.int64),
    # Text form of the label; parsed by `_parse_function` but not returned.
    "clip/label/text": tf.io.FixedLenFeature([], tf.string),
    # Stored as a byte string (not a float list) and decoded to float32
    # with `tf.io.decode_raw` in `_parse_function`.
    "WAVEFORM/feature/floats": tf.io.FixedLenFeature([], tf.string),
}
def _parse_function(example_proto):
    '''Parse one serialized `tf.train.Example` into raw pipeline inputs.

    Args:
        example_proto: scalar string tensor holding a serialized example
            that follows `FEATURE_DESCRIPTION`.

    Returns:
        A `(video, spectrogram, label_map)` tuple:
        video is the flat int8 tensor decoded from 'image/encoded',
        spectrogram is the flat float32 tensor decoded from
        'WAVEFORM/feature/floats', and label_map is the int64 scalar
        stored under 'clip/label/index'.
    '''
    example = tf.io.parse_single_example(example_proto, FEATURE_DESCRIPTION)

    # 'video_path' and 'clip/label/text' are part of the schema but are not
    # needed downstream, so they are intentionally not extracted here.

    # NOTE(review): the bytes are decoded as *signed* int8, so any byte above
    # 127 comes out negative. Presumably the frames were written as uint8
    # pixels -- confirm against the writer before changing the dtype, because
    # the decode functions downstream depend on what is produced here.
    video = tf.io.decode_raw(example['image/encoded'], tf.int8)
    spectrogram = tf.io.decode_raw(example['WAVEFORM/feature/floats'], tf.float32)
    label_map = example['clip/label/index']
    return video, spectrogram, label_map
def decode_inputs(video, spectrogram, label_map):
    '''Turn the parsed tensors into one (non-augmented) sample dict.

    Args:
        video: flat tensor of frame values; it is reshaped to
            [10, 3, 256, 256], so it must hold exactly that many elements.
        spectrogram: spectrogram tensor, passed through unchanged.
        label_map: label tensor; a leading axis of size 1 is added.

    Returns:
        dict with keys 'video_reshaped' (first frame divided by 255),
        'spectrogram' and 'label_map'.
    '''
    # NOTE(review): decode_train_inputs reshapes the same kind of buffer as
    # [10, 256, 256, 3] and transposes afterwards; here the buffer is read
    # directly as [10, 3, 256, 256]. Confirm which layout the records use.
    clip = tf.reshape(video, [10, 3, 256, 256])

    # Only the first of the ten frames is used, scaled by 1/255.
    first_frame = clip[0] / 255

    return {
        'video_reshaped': first_frame,
        'spectrogram': spectrogram,
        'label_map': tf.expand_dims(label_map, axis=0),
    }
def decode_train_inputs(video, spectrogram, label_map):
    '''Turn the parsed tensors into one augmented training sample dict.

    Both augmentations run as Python code through `tf.py_function`
    (`aug_spec_fn` for the spectrogram, `aug_img_fn` for the frame).

    Args:
        video: flat tensor of frame values; it is reshaped to
            [10, 256, 256, 3], so it must hold exactly that many elements.
        spectrogram: float32 spectrogram tensor.
        label_map: label tensor; a leading axis of size 1 is added.

    Returns:
        dict with keys 'video_reshaped' (augmented first frame, uint8,
        static shape [3, 256, 256]), 'spectrogram' (augmented, float32,
        same static shape as the input) and 'label_map'.
    '''
    # Data augmentation for spectrograms. The static shape is recorded first
    # and re-applied afterwards because the py_function output does not
    # carry it.
    spectrogram_shape = spectrogram.shape
    spec_augmented = tf.py_function(aug_spec_fn, [spectrogram], tf.float32)
    spec_augmented.set_shape(spectrogram_shape)

    # Pick the first of the ten frames.
    frame = tf.reshape(video, [10, 256, 256, 3])[0]

    # The frame is deliberately NOT divided by 255 before augmentation:
    # aug_img_fn starts with `.astype(np.uint8)`, so a frame that had been
    # scaled down by 255 would be truncated to (almost) all zeros and the
    # augmenter would only ever see a blank image. Passing the raw values
    # lets that uint8 cast recover the original pixel bytes instead.
    # NOTE(review): the result stays uint8 (as declared below), i.e. it is
    # not scaled to [0, 1] like the output of decode_inputs -- confirm that
    # the consumer of 'video_reshaped' handles the scaling.
    frame_augmented = tf.py_function(aug_img_fn, [frame], tf.uint8)
    # aug_img_fn moves the channel axis to the front, hence [3, 256, 256].
    frame_augmented.set_shape([3, 256, 256])

    label_map = tf.expand_dims(label_map, axis=0)
    augmented_sample = {
        'video_reshaped': frame_augmented,
        'spectrogram': spec_augmented,
        'label_map': label_map,
    }
    return augmented_sample
def aug_img_fn(frame):
    '''Apply the frame augmentations to a single image (eager helper).

    Meant to be wrapped in `tf.py_function`, which is why `frame` arrives
    as an eager tensor exposing `.numpy()`.

    Args:
        frame: eager tensor holding one image.

    Returns:
        The augmented image as a NumPy array with its last axis moved to
        the front (axes reordered as (2, 0, 1)).
    '''
    # The augmentation pipeline is fed a uint8 NumPy image.
    pixels = frame.numpy().astype(np.uint8)
    transformed = create_frame_transforms(image=pixels)
    # Reorder the axes so the last one (channels) comes first.
    return transformed['image'].transpose(2, 0, 1)
def aug_spec_fn(spec):
    '''Apply the spectrogram augmentations to one spectrogram (eager helper).

    Meant to be wrapped in `tf.py_function`, which is why `spec` arrives
    as an eager tensor exposing `.numpy()`.

    Args:
        spec: eager tensor holding one spectrogram.

    Returns:
        Whatever `create_spec_transforms` stores under the 'spec' key of
        its result.
    '''
    augmented = create_spec_transforms(spec=spec.numpy())
    return augmented['spec']
class FakeAVCelebDatasetTrain:
    '''Batched, augmented training pipeline read from TFRecord shards.

    Attributes:
        args: configuration object; `args.data_dir` is the file pattern
            given to `tf.io.matching_files` and `args.batch_size` is the
            batch size.
        samples: the `tf.data.Dataset` built by `load_features_from_tfrec`.
    '''

    def __init__(self, args):
        self.args = args
        self.samples = self.load_features_from_tfrec()

    def load_features_from_tfrec(self):
        '''Build the training dataset from the TFRecord files.

        Returns:
            A dataset of padded batches of the sample dicts produced by
            `decode_train_inputs`.
        '''
        ds = tf.io.matching_files(self.args.data_dir)
        # Shuffle the shard order, then the individual records.
        files = tf.random.shuffle(ds)
        shards = tf.data.Dataset.from_tensor_slices(files)
        dataset = shards.interleave(tf.data.TFRecordDataset)
        dataset = dataset.shuffle(buffer_size=100)
        dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.map(decode_train_inputs, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.padded_batch(batch_size=self.args.batch_size)
        return dataset

    def __len__(self):
        '''Return the number of batches in `self.samples`.

        The dataset is already batched, so this counts batches rather than
        individual examples. Counting iterates over the whole dataset once.
        '''
        # Count over the pipeline that already exists. The previous code
        # tried to rebuild it with an argument that
        # load_features_from_tfrec() does not accept, so len() always raised
        # TypeError; rebuilding would also have replaced self.samples as a
        # side effect of a read-only query.
        cnt = self.samples.reduce(np.int64(0), lambda x, _: x + 1)
        # Return a plain Python int rather than a NumPy scalar.
        return int(cnt.numpy())
class FakeAVCelebDatasetVal:
    '''Batched validation pipeline (no augmentation) read from TFRecord shards.

    Attributes:
        args: configuration object; `args.data_dir` is the file pattern
            given to `tf.io.matching_files` and `args.batch_size` is the
            batch size.
        samples: the `tf.data.Dataset` built by `load_features_from_tfrec`.
    '''

    def __init__(self, args):
        self.args = args
        self.samples = self.load_features_from_tfrec()

    def load_features_from_tfrec(self):
        '''Build the validation dataset from the TFRecord files.

        Returns:
            A dataset of padded batches of the sample dicts produced by
            `decode_inputs`.
        '''
        ds = tf.io.matching_files(self.args.data_dir)
        # Shard order and records are shuffled here as well, exactly as in
        # the training pipeline.
        files = tf.random.shuffle(ds)
        shards = tf.data.Dataset.from_tensor_slices(files)
        dataset = shards.interleave(tf.data.TFRecordDataset)
        dataset = dataset.shuffle(buffer_size=100)
        dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.map(decode_inputs, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.padded_batch(batch_size=self.args.batch_size)
        return dataset

    def __len__(self):
        '''Return the number of batches in `self.samples`.

        The dataset is already batched, so this counts batches rather than
        individual examples. Counting iterates over the whole dataset once.
        '''
        # Count over the pipeline that already exists. The previous code
        # tried to rebuild it with an argument that
        # load_features_from_tfrec() does not accept, so len() always raised
        # TypeError; rebuilding would also have replaced self.samples as a
        # side effect of a read-only query.
        cnt = self.samples.reduce(np.int64(0), lambda x, _: x + 1)
        # Return a plain Python int rather than a NumPy scalar.
        return int(cnt.numpy())