import tensorflow as tf
import os
import argparse
from sklearn.model_selection import StratifiedShuffleSplit
from tqdm import tqdm
import uuid
import random


def parse_arguments():
    parser = argparse.ArgumentParser(description='Image Data Loader with Augmentation and Splits')
    parser.add_argument('--path', type=str, required=True, help='Path to the folder containing images')
    parser.add_argument('--dim', type=int, default=224, help='Required image dimension')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--target_folder', type=str, required=True, help='Folder to store the train, test, and val splits')
    parser.add_argument('--augment_data', action='store_true', help='Apply data augmentation')
    parser.add_argument('--balance', action='store_true', help='Balance the dataset')
    parser.add_argument('--split_type', type=str, choices=['random', 'stratified'], default='random',
                        help='Type of data split (random or stratified)')
    return parser.parse_args()


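# Example invocation (illustrative only; the script filename and the folder
# paths are placeholder assumptions, not taken from the source):
#
#   python prepare_splits.py --path ./raw_images --target_folder ./splits \
#       --dim 224 --split_type stratified --balance

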
def process_image(file_path, image_size):
    # Read, decode, resize to a square, and clamp pixel values to [0, 1].
    image = tf.io.read_file(file_path)
    # expand_animations=False guarantees a rank-3 tensor even for GIF inputs,
    # which downstream resize/encode_png calls require.
    image = tf.image.decode_image(image, channels=3, dtype=tf.float32, expand_animations=False)
    image = tf.image.resize(image, [image_size, image_size])
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image


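# A minimal usage sketch (assumption: you want batched tensors rather than
# resized copies on disk); example_build_dataset is a hypothetical helper,
# not part of the original pipeline.
def example_build_dataset(image_paths, image_size=224, batch_size=32):
    ds = tf.data.Dataset.from_tensor_slices(image_paths)
    ds = ds.map(lambda p: process_image(p, image_size),
                num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

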
def balance_class_images(image_paths, labels, target_count, image_size, label, label_to_index, output_folder):
    print(f"Balancing class '{label}'...")
    label_idx = label_to_index.get(label)
    if label_idx is None:
        print(f"Label '{label}' not found in label_to_index.")
        return [], []

    image_paths = [img for img, lbl in zip(image_paths, labels) if lbl == label_idx]
    num_images = len(image_paths)

    print(f"Class '{label}' has {num_images} images before balancing.")

    balanced_images = []
    balanced_labels = []

    original_count = num_images
    synthetic_count = 0

    if num_images > target_count:
        # Down-sample: keep a random subset of the original file paths.
        balanced_images.extend(random.sample(image_paths, target_count))
        balanced_labels.extend([label_idx] * target_count)
        print(f"Removed {num_images - target_count} images from class '{label}'.")
    elif num_images < target_count:
        # Keep every original, then top up with augmented copies.
        balanced_images.extend(image_paths)
        balanced_labels.extend([label_idx] * num_images)

        num_to_add = target_count - num_images
        print(f"Class '{label}' needs {num_to_add} additional images for balancing.")

        while num_to_add > 0:
            img_path = random.choice(image_paths)
            image = process_image(img_path, image_size)

            # Generate at most 5 augmented variants per sampled source image.
            for _ in range(min(num_to_add, 5)):
                augmented_image = augment_image(image)
                balanced_images.append(augmented_image)
                balanced_labels.append(label_idx)
                num_to_add -= 1
                synthetic_count += 1

        print(f"Added {synthetic_count} augmented images to class '{label}'.")
    else:
        # Already at the target count: keep every image unchanged.
        balanced_images.extend(image_paths)
        balanced_labels.extend([label_idx] * num_images)

    print(f"Class '{label}' has {len(balanced_images)} images after balancing.")

    class_folder = os.path.join(output_folder, str(label_idx))
    os.makedirs(class_folder, exist_ok=True)

    for img in balanced_images:
        file_path = os.path.join(class_folder, f"{uuid.uuid4()}.png")
        save_image(img, file_path, image_size)

    print(f"Saved {len(balanced_images)} images for class '{label}' (Original: {original_count}, Synthetic: {synthetic_count}).")

    return balanced_images, balanced_labels


def save_image(image, file_path, image_size):
    # Accepts either a path to a source image (loaded and resized first)
    # or an in-memory float tensor in [0, 1].
    if isinstance(image, str):
        image = process_image(image, image_size)
    if isinstance(image, tf.Tensor):
        # saturate=True clamps out-of-range values before the uint8 cast.
        image = tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True)
        image = tf.image.encode_png(image)
    else:
        raise ValueError("Expected image to be a TensorFlow tensor, but got a different type.")

    tf.io.write_file(file_path, image)


def augment_image(image):
    # Random geometric and photometric perturbations.
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
    image = tf.image.random_saturation(image, lower=0.9, upper=1.1)
    image = tf.image.random_hue(image, max_delta=0.1)
    # Brightness/contrast can push values outside [0, 1]; clamp before use.
    return tf.clip_by_value(image, 0.0, 1.0)


def create_datagens():
    return [augment_image]


def balance_data(images, labels, target_count, image_size, unique_labels, label_to_index, output_folder):
    print(f"Balancing data: Target count per class = {target_count}")

    all_balanced_images = []
    all_balanced_labels = []

    for label in tqdm(unique_labels, desc="Balancing classes"):
        balanced_images, balanced_labels = balance_class_images(
            images, labels, target_count, image_size, label, label_to_index, output_folder
        )
        all_balanced_images.extend(balanced_images)
        all_balanced_labels.extend(balanced_labels)

    # Entries that are still file paths are originals; tensors are synthetic.
    total_original_images = sum(1 for img in all_balanced_images if isinstance(img, str))
    total_synthetic_images = len(all_balanced_images) - total_original_images

    print(f"\nTotal saved images: {len(all_balanced_images)} (Original: {total_original_images}, Synthetic: {total_synthetic_images})")

    return all_balanced_images, all_balanced_labels


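# A minimal sketch of one way the parsed --balance flag could be wired in
# (assumptions: balance up to the largest class; example_balance_training_split
# and that target-count choice are hypothetical, not from the original script).
def example_balance_training_split(train_files, train_labels, image_size,
                                   unique_labels, label_to_index, output_folder):
    from collections import Counter
    counts = Counter(train_labels)        # images per encoded label
    target_count = max(counts.values())   # oversample up to the largest class
    return balance_data(train_files, train_labels, target_count, image_size,
                        unique_labels, label_to_index, output_folder)

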
def tf_augment_image(file_path, label, image_size):
    # Eager-mode helper (run via tf.py_function): decode, resize, normalize, augment.
    image = tf.image.resize(tf.image.decode_jpeg(tf.io.read_file(file_path)), [image_size, image_size])
    image = tf.cast(image, tf.float32) / 255.0
    augmented_image = augment_image(image)
    return augmented_image, label


def make_map_fn(image_size):
    # Factory that captures image_size in a closure for use with Dataset.map.
    def map_fn(file_path, label):
        image, label = tf.py_function(
            lambda fp, lbl: tf_augment_image(fp, lbl, image_size),
            [file_path, label], [tf.float32, tf.int32])
        image.set_shape([image_size, image_size, 3])
        label.set_shape([])
        return image, label
    return map_fn


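# A minimal usage sketch for make_map_fn (assumptions: file_paths and labels
# are parallel lists; example_augmented_dataset is a hypothetical helper, not
# part of the original pipeline).
def example_augmented_dataset(file_paths, labels, image_size=224, batch_size=32):
    ds = tf.data.Dataset.from_tensor_slices((file_paths, labels))
    ds = ds.map(make_map_fn(image_size), num_parallel_calls=tf.data.AUTOTUNE)
    return ds.shuffle(buffer_size=1024).batch(batch_size).prefetch(tf.data.AUTOTUNE)

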
def load_and_save_splits(path, image_size, batch_size, balance, datagens, target_folder, split_type):
    # batch_size, balance, and datagens are accepted for interface parity with
    # parse_arguments but are not used in this save-to-disk path.
    all_images = []
    labels = []

    # Expect one subfolder per class, each containing that class's images.
    for class_folder in os.listdir(path):
        class_path = os.path.join(path, class_folder)
        if os.path.isdir(class_path):
            for img_file in os.listdir(class_path):
                img_path = os.path.join(class_path, img_file)
                all_images.append(img_path)
                labels.append(class_folder)

    print(f"Loaded {len(all_images)} images across {len(set(labels))} classes.")
    print(f"Labels found: {set(labels)}")

    unique_labels = list(set(labels))
    label_to_index = {label: idx for idx, label in enumerate(unique_labels)}
    encoded_labels = [label_to_index[label] for label in labels]

    print(f"Label to index mapping: {label_to_index}")

    # First split: 80% train / 20% held out.
    if split_type == 'stratified':
        sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
        train_indices, test_indices = next(sss.split(all_images, encoded_labels))
    else:
        total_images = len(all_images)
        indices = list(range(total_images))
        random.shuffle(indices)
        train_indices = indices[:int(0.8 * total_images)]
        test_indices = indices[int(0.8 * total_images):]

    train_files = [all_images[i] for i in train_indices]
    train_labels = [encoded_labels[i] for i in train_indices]
    test_files = [all_images[i] for i in test_indices]
    test_labels = [encoded_labels[i] for i in test_indices]

    # Second split: divide the held-out 20% evenly into val and test
    # (always stratified, regardless of split_type).
    sss_val = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
    val_indices, test_indices = next(sss_val.split(test_files, test_labels))

    val_files = [test_files[i] for i in val_indices]
    val_labels = [test_labels[i] for i in val_indices]
    test_files = [test_files[i] for i in test_indices]
    test_labels = [test_labels[i] for i in test_indices]

    # Write each split to <target_folder>/<split>/<label>/ and record the
    # original source paths in <split>_list.txt.
    for split_name, file_list, labels_list in [("train", train_files, train_labels),
                                               ("val", val_files, val_labels),
                                               ("test", test_files, test_labels)]:
        split_folder = os.path.join(target_folder, split_name)
        os.makedirs(split_folder, exist_ok=True)
        with open(os.path.join(split_folder, f"{split_name}_list.txt"), 'w') as file_list_file:
            for img_path, label in zip(file_list, labels_list):
                label_folder = os.path.join(split_folder, str(label))
                os.makedirs(label_folder, exist_ok=True)
                file_list_file.write(f"{img_path}\n")
                save_image(img_path, os.path.join(label_folder, f"{uuid.uuid4()}.png"), image_size)

    print(f"Saved splits: train: {len(train_files)}, val: {len(val_files)}, test: {len(test_files)}.")


def main():
    args = parse_arguments()
    load_and_save_splits(args.path, args.dim, args.batch_size, args.balance,
                         create_datagens(), args.target_folder, args.split_type)


if __name__ == "__main__":
    main()