Spaces:
Running
Running
File size: 9,832 Bytes
b7f710c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 |
import os
import zipfile
import requests
import json
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import imgaug.augmenters as iaa
import sys
import argparse
import shutil
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.slimface.data.data_processing import process_image
def _download_dataset_zip(dataset_slug, zip_path, timeout=30):
    """Stream the Kaggle archive for ``dataset_slug`` into ``zip_path``.

    The data is written to a ``.part`` file next to ``zip_path`` and only
    moved into place once the transfer has finished, so an interrupted
    download never leaves a truncated ZIP that a later run would reuse.

    Args:
        dataset_slug (str): Dataset slug in 'username/dataset-name' format.
        zip_path (str): Final location of the downloaded archive.
        timeout (float): Connect/read timeout in seconds passed to requests.

    Raises:
        ValueError: If dataset_slug is not exactly 'username/dataset-name'.
        RuntimeError: If the server answers with a non-200 status code.
    """
    username, separator, dataset_name = dataset_slug.partition("/")
    if not (separator and username and dataset_name) or "/" in dataset_name:
        raise ValueError("Invalid dataset slug format. Expected 'username/dataset-name'")

    dataset_url = f"https://www.kaggle.com/api/v1/datasets/download/{username}/{dataset_name}"
    print(f"Downloading dataset {dataset_slug}...")
    partial_path = zip_path + ".part"
    try:
        # The context manager releases the connection even if writing fails.
        with requests.get(dataset_url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                raise RuntimeError(f"Failed to download dataset: {response.status_code}")
            total_size = int(response.headers.get("content-length", 0))
            with open(partial_path, "wb") as file, tqdm(
                desc="Downloading dataset",
                # An unknown size (missing header) is shown as a plain counter.
                total=total_size or None,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        file.write(chunk)
                        pbar.update(len(chunk))
        os.replace(partial_path, zip_path)
    finally:
        # Only still present when the download did not complete.
        if os.path.exists(partial_path):
            os.remove(partial_path)


def _collect_person_images(source_dir):
    """Map each person sub-folder of ``source_dir`` to its image file names.

    Folder and file names are sorted because ``os.listdir`` returns entries
    in arbitrary order, which would make the seeded split differ between
    machines.
    """
    person_files = {}
    for person in sorted(os.listdir(source_dir)):
        person_dir = os.path.join(source_dir, person)
        if os.path.isdir(person_dir):
            person_files[person] = sorted(
                f for f in os.listdir(person_dir)
                if os.path.isfile(os.path.join(person_dir, f))
                and f.lower().endswith((".png", ".jpg", ".jpeg"))
            )
    return person_files


def _build_augmenter(augment, rotation_range):
    """Return the imgaug pipeline used for augmentation, or None if disabled."""
    if not augment:
        return None
    return iaa.Sequential([
        iaa.Fliplr(p=1.0),
        iaa.Sometimes(
            0.5,
            iaa.Affine(rotate=(-rotation_range, rotation_range))
        ),
    ])


def download_and_split_kaggle_dataset(
    dataset_slug,
    base_dir="data",
    augment=False,
    random_state=42,
    test_split_rate=0.2,
    rotation_range=15,
    source_subdir="Original Images/Original Images",
    delete_raw=False
):
    """Download a Kaggle dataset, split it into train/validation sets, and process images for face recognition.

    Skips downloading if the ZIP already exists; the archive is always
    (re-)extracted into the raw folder. Every image is run through
    ``process_image`` into a temporary folder, then the resulting files are
    split per person and moved to ``processed_ds/train_data/<person>`` and
    ``processed_ds/val_data/<person>``. Optionally deletes the raw folder to
    save storage.

    Args:
        dataset_slug (str): Dataset slug in 'username/dataset-name' format.
        base_dir (str): Base directory for storing dataset.
        augment (bool): Whether to pass an augmentation pipeline to ``process_image``.
        random_state (int): Random seed for reproducibility in train-test split.
        test_split_rate (float): Proportion of data to use for validation (between 0 and 1, exclusive).
        rotation_range (int): Maximum rotation angle in degrees for augmentation.
        source_subdir (str): Subdirectory within raw_dir containing images.
        delete_raw (bool): Whether to delete the raw folder after processing to save storage.

    Raises:
        ValueError: If test_split_rate is not between 0 and 1, or if a download
            is needed and dataset_slug is invalid.
        FileNotFoundError: If source directory is not found.
        Exception: If dataset download fails or other errors occur. Every
            error is printed and then re-raised unchanged.
    """
    try:
        # Validate before touching the filesystem or the network.
        if not 0 < test_split_rate < 1:
            raise ValueError("test_split_rate must be between 0 and 1")

        # Set up directories
        raw_dir = os.path.join(base_dir, "raw")
        processed_dir = os.path.join(base_dir, "processed_ds")
        train_dir = os.path.join(processed_dir, "train_data")
        val_dir = os.path.join(processed_dir, "val_data")
        temp_dir = os.path.join(processed_dir, "temp")
        zip_path = os.path.join(raw_dir, "dataset.zip")
        os.makedirs(raw_dir, exist_ok=True)
        os.makedirs(processed_dir, exist_ok=True)

        if os.path.exists(zip_path):
            print(f"ZIP file already exists at {zip_path}, skipping download.")
        else:
            _download_dataset_zip(dataset_slug, zip_path)

        print("Extracting dataset...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(raw_dir)

        source_dir = os.path.join(raw_dir, source_subdir)
        if not os.path.exists(source_dir):
            raise FileNotFoundError(f"Source directory {source_dir} not found")

        person_files = _collect_person_images(source_dir)
        aug = _build_augmenter(augment, rotation_range)

        # Process and split files with progress bar
        total_files = sum(len(images) for images in person_files.values())
        with tqdm(total=total_files, desc="Processing and copying files", unit="file") as pbar:
            for person, images in person_files.items():
                train_person_dir = os.path.join(train_dir, person)
                val_person_dir = os.path.join(val_dir, person)
                os.makedirs(train_person_dir, exist_ok=True)
                os.makedirs(val_person_dir, exist_ok=True)
                os.makedirs(temp_dir, exist_ok=True)
                try:
                    # Process images (and augmentations) before splitting.
                    # NOTE(review): process_image presumably returns the file
                    # names it saved into temp_dir - confirm against its source.
                    all_image_filenames = []
                    for img in images:
                        src_path = os.path.join(source_dir, person, img)
                        saved_images = process_image(src_path, temp_dir, aug)
                        all_image_filenames.extend(saved_images)
                        pbar.update(1)

                    if len(all_image_filenames) < 2:
                        # train_test_split raises when either side of the
                        # split would be empty, so keep everything in train.
                        print(
                            f"Warning: {person} has {len(all_image_filenames)} processed "
                            "image(s), too few to split; keeping all for training."
                        )
                        train_names = set(all_image_filenames)
                    else:
                        train_split, _ = train_test_split(
                            all_image_filenames,
                            test_size=test_split_rate,
                            random_state=random_state,
                        )
                        # A set gives O(1) membership tests in the loop below.
                        train_names = set(train_split)

                    # Move images to final train/val directories
                    for img in all_image_filenames:
                        src = os.path.join(temp_dir, img)
                        if not os.path.exists(src):
                            print(f"Warning: File {src} not found, skipping.")
                            continue
                        target_dir = train_person_dir if img in train_names else val_person_dir
                        # os.replace overwrites leftovers from an earlier run on
                        # every platform; os.rename fails on Windows in that case.
                        os.replace(src, os.path.join(target_dir, img))
                finally:
                    # Remove the temp folder even when processing failed.
                    shutil.rmtree(temp_dir, ignore_errors=True)
                print(f"\nCleaned up temp directory for {person}")

        # Optionally delete raw folder to save storage
        if delete_raw:
            print(f"Deleting raw folder {raw_dir} to save storage...")
            shutil.rmtree(raw_dir, ignore_errors=True)
            print(f"Raw folder {raw_dir} deleted.")

        print(f"Dataset {dataset_slug} downloaded, extracted, processed, and split successfully!")
    except Exception as e:
        # Top-level boundary: report, then let the caller see the original error.
        print(f"Error processing dataset: {e}")
        raise
if __name__ == "__main__":
    # Command-line entry point. Each entry pairs a flag with the keyword
    # options handed to add_argument, in the order shown by --help.
    cli_options = [
        ("--dataset_slug", dict(
            type=str,
            default="vasukipatel/face-recognition-dataset",
            help="Kaggle dataset slug in 'username/dataset-name' format",
        )),
        ("--base_dir", dict(
            type=str,
            default="./data",
            help="Base directory for storing dataset",
        )),
        ("--augment", dict(
            action="store_true",
            help="Enable data augmentation",
        )),
        ("--random_state", dict(
            type=int,
            default=42,
            help="Random seed for train-test split reproducibility",
        )),
        ("--test_split_rate", dict(
            type=float,
            default=0.2,
            help="Proportion of data for validation (between 0 and 1)",
        )),
        ("--rotation_range", dict(
            type=int,
            default=15,
            help="Maximum rotation angle in degrees for augmentation",
        )),
        ("--source_subdir", dict(
            type=str,
            default="Original Images/Original Images",
            help="Subdirectory within raw_dir containing images",
        )),
        ("--delete_raw", dict(
            action="store_true",
            help="Delete the raw folder after processing to save storage",
        )),
    ]

    cli = argparse.ArgumentParser(
        description="Download and process a Kaggle dataset for face recognition."
    )
    for flag, spec in cli_options:
        cli.add_argument(flag, **spec)

    # Every option name equals a keyword parameter of the function, so the
    # parsed namespace can be expanded directly into the call.
    download_and_split_kaggle_dataset(**vars(cli.parse_args()))
|