"""Convert a directory of images into a single TFRecord dataset."""
import argparse
import json
import logging
import os
from typing import Sequence

import numpy as np
import tensorflow as tf
from PIL import Image
from tqdm import tqdm

logger = logging.getLogger(__name__)

# File extensions (compared case-insensitively) that are treated as images.
_IMAGE_EXTENSIONS = frozenset({'png', 'jpg', 'jpeg'})


def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_feature(value):
    """Wrap a single integer value in a ``tf.train.Feature``."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _is_image_file(directory, file_name):
    """Return True if *file_name* in *directory* is a regular file with an image extension."""
    _, dot, extension = file_name.rpartition('.')
    if not dot or extension.lower() not in _IMAGE_EXTENSIONS:
        return False
    # Guard against sub-directories that happen to be named like images.
    return os.path.isfile(os.path.join(directory, file_name))


def _serialize_image(path, label):
    """Load the image at *path* and return it as a serialized ``tf.train.Example``."""
    # PIL opens files lazily and keeps the handle open; the context manager closes it.
    with Image.open(path) as img:
        pixels = np.array(img, dtype=np.uint8)
    # Single-band images (e.g. grayscale mode 'L') decode to a 2-D array; add an
    # explicit channel axis so every record has a (height, width, channels) layout.
    # NOTE(review): palette ('P') images also decode to 2-D and are stored as
    # palette indices, not colours — convert upstream if that matters.
    if pixels.ndim == 2:
        pixels = pixels[:, :, np.newaxis]
    height, width, channels = pixels.shape
    example = tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'channels': _int64_feature(channels),
        # Raw uint8 pixels in C (row-major) order: height x width x channels.
        'image': _bytes_feature(pixels.tobytes()),
        'label': _int64_feature(label),
    }))
    return example.SerializeToString()


def images_to_tfrecords(image_dir, data_dir, has_labels):
    """Convert a folder of images to a TFRecord file.

    The image directory should have one of the following structures.

    If ``has_labels`` is False, ``image_dir`` should look like this::

        path/to/image_dir/
            0.jpg
            1.jpg
            2.jpg
            4.jpg
            ...

    If ``has_labels`` is True, ``image_dir`` should look like this::

        path/to/image_dir/
            label0/
                0.jpg
                1.jpg
                ...
            label1/
                a.jpg
                b.jpg
                c.jpg
                ...
            ...

    Label directories are sorted by name and numbered from 0, so
    label0 -> 0, label1 -> 1, and so on. Unlabelled images all get the
    dummy label 0. Only ``.png``, ``.jpg`` and ``.jpeg`` files
    (case-insensitive) are read; other entries are skipped.

    Two files are written to ``data_dir``:

    * ``dataset.tfrecords`` holds one ``tf.train.Example`` per image. Each
      example has the int64 features ``height``, ``width``, ``channels`` and
      ``label``, and the bytes feature ``image``, which contains the raw uint8
      pixels in row-major order.
    * ``dataset_info.json`` holds ``num_examples`` and ``num_classes``.

    Args:
        image_dir (str): Path to images.
        data_dir (str): Path where the TFRecords dataset is stored. It is
            created if it does not exist.
        has_labels (bool): If True, ``image_dir`` contains label directories.

    Returns:
        dict | None: Dataset info (``num_examples``, ``num_classes``). Returns
        None if ``has_labels`` is True but ``image_dir`` contains a
        non-directory entry. In that case no ``dataset.tfrecords`` is left
        in ``data_dir``.
    """
    os.makedirs(data_dir, exist_ok=True)
    record_path = os.path.join(data_dir, 'dataset.tfrecords')

    if has_labels:
        # Sort so label ids are reproducible: os.listdir order is arbitrary.
        label_dirs = sorted(os.listdir(image_dir))
        stray_entries = [name for name in label_dirs
                         if not os.path.isdir(os.path.join(image_dir, name))]
        # Validate the layout before writing anything, so no partial dataset
        # is produced for an invalid directory.
        if stray_entries:
            logger.warning('The image directory should contain one directory for each label.')
            logger.warning('These label directories should contain the image files.')
            logger.warning('Unexpected non-directory entries: %s', stray_entries)
            if os.path.exists(record_path):
                os.remove(record_path)
            return None
        sources = [(os.path.join(image_dir, name), label)
                   for label, name in enumerate(label_dirs)]
        num_classes = len(label_dirs)
    else:
        sources = [(image_dir, 0)]  # dummy label 0 for unlabelled data
        num_classes = 0

    num_examples = 0
    # The context manager closes the writer even if an image fails to decode.
    with tf.io.TFRecordWriter(record_path) as writer:
        for directory, label in sources:
            image_files = [name for name in sorted(os.listdir(directory))
                           if _is_image_file(directory, name)]
            for img_file in tqdm(image_files):
                writer.write(_serialize_image(os.path.join(directory, img_file), label))
                num_examples += 1

    dataset_info = {'num_examples': num_examples, 'num_classes': num_classes}
    with open(os.path.join(data_dir, 'dataset_info.json'), 'w') as fout:
        json.dump(dataset_info, fout)
    return dataset_info


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', type=str, required=True,
                        help='Path to the image directory.')
    parser.add_argument('--data_dir', type=str, required=True,
                        help='Path where the TFRecords dataset is stored.')
    parser.add_argument('--has_labels', action='store_true',
                        help='If True, image_dir contains label directories.')
    args = parser.parse_args()
    images_to_tfrecords(args.image_dir, args.data_dir, args.has_labels)