File size: 3,465 Bytes
81170fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import tensorflow as tf
import tensorflow_datasets as tfds
import jax
import flax
import numpy as np
from PIL import Image
import os
from typing import Sequence
from tqdm import tqdm
import json
from tqdm import tqdm
import logging

logger = logging.getLogger(__name__)


def prefetch(dataset, n_prefetch):
    """Wrap a dataset into an iterator of NumPy pytrees, optionally prefetching to device.

    Adapted from:
    https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py

    Args:
        dataset: Iterable (e.g. a tf.data.Dataset) yielding pytrees whose
            leaves support the buffer protocol (tf tensors / numpy arrays).
        n_prefetch (int): Number of batches to prefetch onto the local
            devices; 0 or None disables device prefetching.

    Returns:
        Iterator yielding pytrees of np.ndarray (or device arrays when
        prefetching is enabled).
    """
    ds_iter = iter(dataset)
    # memoryview avoids an extra host-side copy when converting each tensor
    # to a NumPy array. NOTE: jax.tree_map was deprecated and later removed;
    # jax.tree_util.tree_map is the supported, backward-compatible spelling.
    ds_iter = map(lambda x: jax.tree_util.tree_map(lambda t: np.asarray(memoryview(t)), x),
                  ds_iter)
    if n_prefetch:
        ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
    return ds_iter


def get_data(data_dir, img_size, img_channels, num_classes, num_local_devices, batch_size, shuffle_buffer=1000):
    """Build the training input pipeline from a single TFRecord file.

    Reads `dataset.tfrecords` and `dataset_info.json` from `data_dir`,
    decodes/augments the images, batches with the per-worker batch size,
    and reshapes each batch for `jax.pmap` sharding.

    Args:
        data_dir (str): Root directory of the dataset; must contain
            'dataset.tfrecords' and 'dataset_info.json'.
        img_size (int): Image size for training; images are resized to
            img_size x img_size.
        img_channels (int): Number of image channels.
        num_classes (int): Number of classes, 0 for no classes.
        num_local_devices (int): Number of devices.
        batch_size (int): Batch size (per device).
        shuffle_buffer (int): Buffer used for shuffling the dataset.

    Returns:
        (tf.data.Dataset, dict): The batched, device-sharded dataset and the
        parsed contents of 'dataset_info.json'.
    """

    def pre_process(serialized_example):
        # Parse one serialized tf.train.Example: a raw uint8 image buffer
        # plus its dimensions and an integer class label.
        feature = {'height': tf.io.FixedLenFeature([], tf.int64),
                   'width': tf.io.FixedLenFeature([], tf.int64),
                   'channels': tf.io.FixedLenFeature([], tf.int64),
                   'image': tf.io.FixedLenFeature([], tf.string),
                   'label': tf.io.FixedLenFeature([], tf.int64)}
        example = tf.io.parse_single_example(serialized_example, feature)

        height = tf.cast(example['height'], dtype=tf.int64)
        width = tf.cast(example['width'], dtype=tf.int64)
        channels = tf.cast(example['channels'], dtype=tf.int64)

        # The image was stored as raw bytes, so decode and restore its shape
        # from the per-example height/width/channels features.
        image = tf.io.decode_raw(example['image'], out_type=tf.uint8)
        image = tf.reshape(image, shape=[height, width, channels])

        image = tf.cast(image, dtype='float32')
        image = tf.image.resize(image, size=[img_size, img_size], method='bicubic', antialias=True)
        image = tf.image.random_flip_left_right(image)
        
        # Normalize uint8 pixel range [0, 255] to [-1, 1].
        image = (image - 127.5) / 127.5
        
        label = tf.one_hot(example['label'], num_classes)
        return {'image': image, 'label': label}

    def shard(data):
        # Reshape images from [num_devices * batch_size, H, W, C] to [num_devices, batch_size, H, W, C]
        # because the first dimension will be mapped across devices using jax.pmap
        data['image'] = tf.reshape(data['image'], [num_local_devices, -1, img_size, img_size, img_channels])
        data['label'] = tf.reshape(data['label'], [num_local_devices, -1, num_classes])
        return data

    logger.info('Loading TFRecord...')
    # tf.io.gfile supports both local paths and remote filesystems (e.g. GCS).
    with tf.io.gfile.GFile(os.path.join(data_dir, 'dataset_info.json'), 'r') as fin:
        dataset_info = json.load(fin)

    ds = tf.data.TFRecordDataset(filenames=os.path.join(data_dir, 'dataset.tfrecords'))
    # Each JAX process (worker) reads a disjoint slice of the records.
    ds = ds.shard(jax.process_count(), jax.process_index())
    ds = ds.shuffle(min(dataset_info['num_examples'], shuffle_buffer))
    ds = ds.map(pre_process, tf.data.AUTOTUNE)
    # drop_remainder=True keeps a fixed batch shape, required for the shard
    # reshape below (and for pmap'd training steps).
    ds = ds.batch(batch_size * num_local_devices, drop_remainder=True)  # uses per-worker batch size
    ds = ds.map(shard, tf.data.AUTOTUNE)
    ds = ds.prefetch(1)  # prefetches the next batch
    return ds, dataset_info