'''Module for loading the FakeAVCeleb dataset from the TFRecord format.'''
import numpy as np
import tensorflow as tf
from data.augmentation_utils import create_frame_transforms, create_spec_transforms

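# Schema of one serialized tf.train.Example in the FakeAVCeleb TFRecords:
# the source video path, raw frame bytes, an integer class index, the
# corresponding text label, and the audio spectrogram stored as raw floats.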
FEATURE_DESCRIPTION = {
    'video_path': tf.io.FixedLenFeature([], tf.string),
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'clip/label/index': tf.io.FixedLenFeature([], tf.int64),
    'clip/label/text': tf.io.FixedLenFeature([], tf.string),
    'WAVEFORM/feature/floats': tf.io.FixedLenFeature([], tf.string),
}

@tf.function
def _parse_function(example_proto):
    '''Parse a serialized `tf.train.Example` proto using FEATURE_DESCRIPTION.'''
    example = tf.io.parse_single_example(example_proto, FEATURE_DESCRIPTION)

    # Frames are stored as raw uint8 pixel bytes, the spectrogram as raw float32 bytes.
    video = tf.io.decode_raw(example['image/encoded'], tf.uint8)
    spectrogram = tf.io.decode_raw(example['WAVEFORM/feature/floats'], tf.float32)
    label_map = example['clip/label/index']

    return video, spectrogram, label_map

@tf.function
def decode_inputs(video, spectrogram, label_map):
    '''Decode raw tensors into arrays with the desired shapes (no augmentation).'''
    # Frames are reshaped channels-first: (frames, channels, height, width).
    frame = tf.reshape(video, [10, 3, 256, 256])
    frame = frame[0] / 255  # Pick the first frame and normalize it to [0, 1].

    label_map = tf.expand_dims(label_map, axis=0)

    sample = {'video_reshaped': frame, 'spectrogram': spectrogram, 'label_map': label_map}
    return sample


def decode_train_inputs(video, spectrogram, label_map):
    '''Decode raw tensors and apply data augmentation for training.'''
    # Data augmentation for spectrograms.
    spectrogram_shape = spectrogram.shape
    spec_augmented = tf.py_function(aug_spec_fn, [spectrogram], tf.float32)
    spec_augmented.set_shape(spectrogram_shape)

    # Frames are reshaped channels-last here: (frames, height, width, channels).
    frame = tf.reshape(video, [10, 256, 256, 3])
    frame = frame[0]  # Pick the first frame.

    # Augment the raw uint8 frame first, then normalize; the frame
    # transforms expect integer pixel values and return a CHW image.
    frame_augmented = tf.py_function(aug_img_fn, [frame], tf.uint8)
    frame_augmented.set_shape([3, 256, 256])
    frame_augmented = tf.cast(frame_augmented, tf.float32) / 255  # Normalize to [0, 1].

    label_map = tf.expand_dims(label_map, axis=0)

    augmented_sample = {'video_reshaped': frame_augmented,
                        'spectrogram': spec_augmented,
                        'label_map': label_map}
    return augmented_sample


def aug_img_fn(frame):
    '''Apply the frame transforms to a single HWC uint8 frame and return it as CHW.'''
    frame = frame.numpy().astype(np.uint8)
    frame_data = {'image': frame}
    aug_frame_data = create_frame_transforms(**frame_data)
    aug_img = aug_frame_data['image']
    aug_img = aug_img.transpose(2, 0, 1)  # HWC -> CHW
    return aug_img


def aug_spec_fn(spec):
    '''Apply the spectrogram transforms to a single spectrogram.'''
    spec = spec.numpy()
    spec_data = {'spec': spec}
    aug_spec_data = create_spec_transforms(**spec_data)
    aug_spec = aug_spec_data['spec']
    return aug_spec


class FakeAVCelebDatasetTrain:
    '''Training split: shuffled, augmented and batched samples from TFRecord shards.'''

    def __init__(self, args):
        self.args = args
        self.samples = self.load_features_from_tfrec()

    def load_features_from_tfrec(self):
        '''Load raw features from the TFRecord shards and return a batched tf.data.Dataset.'''
        files = tf.io.matching_files(self.args.data_dir)
        files = tf.random.shuffle(files)

        shards = tf.data.Dataset.from_tensor_slices(files)
        dataset = shards.interleave(tf.data.TFRecordDataset)
        dataset = dataset.shuffle(buffer_size=100)

        dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.map(decode_train_inputs, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.padded_batch(batch_size=self.args.batch_size)
        return dataset

    def __len__(self):
        # Count the batches in the already-loaded dataset.
        cnt = self.samples.reduce(np.int64(0), lambda x, _: x + 1)
        return cnt.numpy()


class FakeAVCelebDatasetVal:
    '''Validation split: shuffled and batched samples without augmentation.'''

    def __init__(self, args):
        self.args = args
        self.samples = self.load_features_from_tfrec()

    def load_features_from_tfrec(self):
        '''Load raw features from the TFRecord shards and return a batched tf.data.Dataset.'''
        files = tf.io.matching_files(self.args.data_dir)
        files = tf.random.shuffle(files)

        shards = tf.data.Dataset.from_tensor_slices(files)
        dataset = shards.interleave(tf.data.TFRecordDataset)
        dataset = dataset.shuffle(buffer_size=100)

        dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.map(decode_inputs, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.padded_batch(batch_size=self.args.batch_size)
        return dataset

    def __len__(self):
        # Count the batches in the already-loaded dataset.
        cnt = self.samples.reduce(np.int64(0), lambda x, _: x + 1)
        return cnt.numpy()
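

# A minimal usage sketch (illustrative only, not part of the module's API):
# `args` is assumed to expose `data_dir`, a glob pattern matching the TFRecord
# shards, and `batch_size`, e.g. as parsed by argparse. The path below is a
# placeholder.
if __name__ == '__main__':
    from types import SimpleNamespace

    demo_args = SimpleNamespace(data_dir='data/train/*.tfrecord', batch_size=8)
    train_data = FakeAVCelebDatasetTrain(demo_args)
    for batch in train_data.samples.take(1):
        print(batch['video_reshaped'].shape,
              batch['spectrogram'].shape,
              batch['label_map'].shape)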