# Source: stylegan2-flax-tpu/dataset_utils/images_to_tfrecords.py
# Added by akhaliq (HF staff) in commit 81170fd ("add files"); 4.83 kB.
import tensorflow as tf
import numpy as np
from PIL import Image
from typing import Sequence
from tqdm import tqdm
import argparse
import json
import os
import logging
# Module-level logger; used by images_to_tfrecords to warn about an invalid
# label-directory layout.
logger = logging.getLogger(__name__)
# File extensions (lower-case, with dot) that are converted; others are skipped.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_feature(value):
    """Wrap a single integer in a tf.train.Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _is_image_file(filename):
    """Return True if *filename* has a supported image extension (case-insensitive)."""
    return os.path.splitext(filename)[1].lower() in _IMAGE_EXTENSIONS


def _image_to_hwc(img):
    """Return *img* as a uint8 array of shape (height, width, channels).

    Single-channel images (e.g. grayscale) convert to 2-D arrays; they get a
    trailing channel axis of size 1 so every record has a valid 'channels'
    value instead of crashing on ``shape[2]``.
    """
    arr = np.array(img, dtype=np.uint8)
    if arr.ndim == 2:
        arr = arr[:, :, np.newaxis]
    return arr


def _load_image(path):
    """Decode the image at *path* into an HWC uint8 array and close the file."""
    with Image.open(path) as img:
        return _image_to_hwc(img)


def _make_example(img, label):
    """Build a tf.train.Example holding the raw pixel bytes of *img* and *label*."""
    height, width, channels = img.shape
    return tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'channels': _int64_feature(channels),
        'image': _bytes_feature(img.tobytes()),
        'label': _int64_feature(label),
    }))


def images_to_tfrecords(image_dir: str, data_dir: str, has_labels: bool):
    """Convert a folder of images into a single TFRecord file.

    Writes ``dataset.tfrecords`` and ``dataset_info.json`` into ``data_dir``
    (created if missing). Each record stores the raw, row-major uint8 pixel
    buffer of one image plus its ``height``, ``width``, ``channels`` and
    ``label``.

    Expected layout of ``image_dir``.

    If ``has_labels`` is False::

        path/to/image_dir/
            0.jpg
            1.jpg
            ...

    If ``has_labels`` is True::

        path/to/image_dir/
            label0/
                0.jpg
                1.jpg
                ...
            label1/
                a.jpg
                b.jpg
                ...

    Label directories get integer labels in sorted-name order
    (label0 -> 0, label1 -> 1, ...). Only .png/.jpg/.jpeg files
    (case-insensitive) are converted; other files are skipped. In unlabeled
    mode every record gets the dummy label 0 and ``num_classes`` is 0.

    Args:
        image_dir (str): Path to images.
        data_dir (str): Path where the TFRecords dataset is stored.
        has_labels (bool): If True, 'image_dir' contains label directories.

    Returns:
        (dict): Dataset info ``{'num_examples': int, 'num_classes': int}``,
        also written to ``dataset_info.json``. Returns None if ``has_labels``
        is True and ``image_dir`` contains a non-directory entry. In that case
        a warning is logged, any existing ``dataset.tfrecords`` in
        ``data_dir`` is removed, and nothing is written.
    """
    os.makedirs(data_dir, exist_ok=True)
    tfrecord_path = os.path.join(data_dir, 'dataset.tfrecords')

    # Build the list of (directory, label) sources before opening the writer,
    # so an invalid layout never leaves a half-written or open file behind.
    if has_labels:
        # os.listdir order is arbitrary; sort so label ids are reproducible.
        label_dirs = sorted(os.listdir(image_dir))
        if not all(os.path.isdir(os.path.join(image_dir, d)) for d in label_dirs):
            logger.warning('The image directory should contain one directory for each label.')
            logger.warning('These label directories should contain the image files.')
            if os.path.exists(tfrecord_path):
                os.remove(tfrecord_path)
            return None
        sources = [(os.path.join(image_dir, d), label) for label, d in enumerate(label_dirs)]
        num_classes = len(label_dirs)
    else:
        sources = [(image_dir, 0)]  # dummy label 0 for every image
        num_classes = 0

    num_examples = 0
    # The context manager closes (and flushes) the writer even if decoding fails.
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for src_dir, label in sources:
            img_files = sorted(f for f in os.listdir(src_dir) if _is_image_file(f))
            for img_file in tqdm(img_files):
                img = _load_image(os.path.join(src_dir, img_file))
                writer.write(_make_example(img, label).SerializeToString())
                num_examples += 1

    dataset_info = {'num_examples': num_examples, 'num_classes': num_classes}
    with open(os.path.join(data_dir, 'dataset_info.json'), 'w') as fout:
        json.dump(dataset_info, fout)
    return dataset_info
if __name__ == '__main__':
    # Command-line entry point: convert an image directory into TFRecords.
    # --image_dir and --data_dir are required; without them the conversion
    # would fail later with an opaque TypeError (e.g. os.makedirs(None)).
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', type=str, required=True, help='Path to the image directory.')
    parser.add_argument('--data_dir', type=str, required=True, help='Path where the TFRecords dataset is stored.')
    parser.add_argument('--has_labels', action='store_true', help='If True, image_dir contains label directories.')
    args = parser.parse_args()
    images_to_tfrecords(args.image_dir, args.data_dir, args.has_labels)