import argparse
import os
import shutil
import sys
import zipfile

import imgaug.augmenters as iaa
import requests
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Make the project root importable when this script is run directly.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.slimface.data.data_processing import process_image
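# Assumed contract (inferred from usage below): process_image(src_path, dst_dir,
# augmenter) preprocesses one image for face recognition, writes the result
# (plus augmented copies when an augmenter is given) into dst_dir, and returns
# the list of saved filenames.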


def download_and_split_kaggle_dataset(
    dataset_slug,
    base_dir="data",
    augment=False,
    random_state=42,
    test_split_rate=0.2,
    rotation_range=15,
    source_subdir="Original Images/Original Images",
    delete_raw=False
):
"""Download a Kaggle dataset, split it into train/validation sets, and process images for face recognition. | |
Skips downloading if ZIP exists and unzipping if raw folder contains files. | |
Optionally deletes the raw folder to save storage. | |
Args: | |
dataset_slug (str): Dataset slug in 'username/dataset-name' format. | |
base_dir (str): Base directory for storing dataset. | |
augment (bool): Whether to apply data augmentation to training images. | |
random_state (int): Random seed for reproducibility in train-test split. | |
test_split_rate (float): Proportion of data to use for validation (between 0 and 1). | |
rotation_range (int): Maximum rotation angle in degrees for augmentation. | |
source_subdir (str): Subdirectory within raw_dir containing images. | |
delete_raw (bool): Whether to delete the raw folder after processing to save storage. | |
Raises: | |
ValueError: If test_split_rate is not between 0 and 1 or dataset_slug is invalid. | |
FileNotFoundError: If source directory is not found. | |
Exception: If dataset download fails or other errors occur. | |
""" | |
    try:
        # Validate test_split_rate
        if not 0 < test_split_rate < 1:
            raise ValueError("test_split_rate must be between 0 and 1")

        # Validate the dataset slug up front, before any downloading
        # (a plain tuple unpack would raise a confusing ValueError on bad input)
        slug_parts = dataset_slug.split("/")
        if len(slug_parts) != 2 or not all(slug_parts):
            raise ValueError("Invalid dataset slug format. Expected 'username/dataset-name'")
        username, dataset_name = slug_parts

        # Set up directories
        raw_dir = os.path.join(base_dir, "raw")
        processed_dir = os.path.join(base_dir, "processed_ds")
        train_dir = os.path.join(processed_dir, "train_data")
        val_dir = os.path.join(processed_dir, "val_data")
        zip_path = os.path.join(raw_dir, "dataset.zip")
        os.makedirs(raw_dir, exist_ok=True)
        os.makedirs(processed_dir, exist_ok=True)

        # Check if the ZIP file already exists
        if os.path.exists(zip_path):
            print(f"ZIP file already exists at {zip_path}, skipping download.")
        else:
            # Download dataset with a progress bar
            dataset_url = f"https://www.kaggle.com/api/v1/datasets/download/{username}/{dataset_name}"
            print(f"Downloading dataset {dataset_slug}...")
            response = requests.get(dataset_url, stream=True, timeout=60)
            if response.status_code != 200:
                raise RuntimeError(f"Failed to download dataset: {response.status_code}")
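            # Note: anonymous requests like this typically work only for public
            # datasets; private datasets require Kaggle API credentials (e.g. as
            # configured for the official kaggle CLI).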
            total_size = int(response.headers.get("content-length", 0))
            with open(zip_path, "wb") as file, tqdm(
                desc="Downloading dataset",
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        file.write(chunk)
                        pbar.update(len(chunk))
        # Skip extraction if the raw directory already contains files other than the ZIP
        zip_filename = os.path.basename(zip_path)
        if any(name != zip_filename for name in os.listdir(raw_dir)):
            print(f"Raw directory {raw_dir} already contains files, skipping extraction.")
        else:
            # Extract dataset
            print("Extracting dataset...")
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(raw_dir)
        # Define source directory
        source_dir = os.path.join(raw_dir, source_subdir)
        if not os.path.exists(source_dir):
            raise FileNotFoundError(f"Source directory {source_dir} not found")

        # Group files by person (subfolder names)
        person_files = {}
        for person in os.listdir(source_dir):
            person_dir = os.path.join(source_dir, person)
            if os.path.isdir(person_dir):
                person_files[person] = [
                    f for f in os.listdir(person_dir)
                    if os.path.isfile(os.path.join(person_dir, f))
                    and f.lower().endswith((".png", ".jpg", ".jpeg"))
                ]
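        # person_files now maps identity -> image filenames, e.g. (hypothetical
        # names) {"person_a": ["001.jpg", "002.jpg"], "person_b": [...]}, based
        # on the one-folder-per-person layout of the source directory.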
        # Define augmentation pipeline: every augmented copy is mirrored, and
        # about half are additionally rotated within +/-rotation_range degrees
        if augment:
            aug = iaa.Sequential([
                iaa.Fliplr(p=1.0),
                iaa.Sometimes(
                    0.5,
                    iaa.Affine(rotate=(-rotation_range, rotation_range))
                ),
            ])
        else:
            aug = None
        # Process and split files with a progress bar
        total_files = sum(len(images) for images in person_files.values())
        with tqdm(total=total_files, desc="Processing and copying files", unit="file") as pbar:
            for person, images in person_files.items():
                # Set up directories for this person
                train_person_dir = os.path.join(train_dir, person)
                val_person_dir = os.path.join(val_dir, person)
                temp_dir = os.path.join(processed_dir, "temp")
                os.makedirs(train_person_dir, exist_ok=True)
                os.makedirs(val_person_dir, exist_ok=True)
                os.makedirs(temp_dir, exist_ok=True)
                all_image_filenames = []

                # Process images and create augmentations before splitting
                # (aug is already None when augment is False)
                for img in images:
                    src_path = os.path.join(source_dir, person, img)
                    saved_images = process_image(src_path, temp_dir, aug)
                    all_image_filenames.extend(saved_images)
                    pbar.update(1)
                # Split all images (original and augmented) for this person;
                # guard against classes too small to populate both splits
                if len(all_image_filenames) < 2:
                    train_images_filenames = all_image_filenames
                    val_images_filenames = []
                else:
                    train_images_filenames, val_images_filenames = train_test_split(
                        all_image_filenames,
                        test_size=test_split_rate,
                        random_state=random_state,
                    )
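                # Note: augmented copies are split alongside their source
                # images, so variants of the same photo can land in both train
                # and val; the validation set is not fully independent when
                # augment=True.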
                # Move images to final train/val directories
                train_set = set(train_images_filenames)  # O(1) membership checks
                for img in all_image_filenames:
                    src = os.path.join(temp_dir, img)
                    if not os.path.exists(src):
                        print(f"Warning: File {src} not found, skipping.")
                        continue
                    if img in train_set:
                        dst = os.path.join(train_person_dir, img)
                    else:
                        dst = os.path.join(val_person_dir, img)
                    os.rename(src, dst)

                # Clean up the temporary directory for this person
                shutil.rmtree(temp_dir, ignore_errors=True)
                print(f"\nCleaned up temp directory for {person}")
        # Optionally delete the raw folder to save storage
        if delete_raw:
            print(f"Deleting raw folder {raw_dir} to save storage...")
            shutil.rmtree(raw_dir, ignore_errors=True)
            print(f"Raw folder {raw_dir} deleted.")

        print(f"Dataset {dataset_slug} downloaded, extracted, processed, and split successfully!")
    except Exception as e:
        print(f"Error processing dataset: {e}")
        raise
if __name__ == "__main__": | |
parser = argparse.ArgumentParser(description="Download and process a Kaggle dataset for face recognition.") | |
parser.add_argument( | |
"--dataset_slug", | |
type=str, | |
default="vasukipatel/face-recognition-dataset", | |
help="Kaggle dataset slug in 'username/dataset-name' format" | |
) | |
parser.add_argument( | |
"--base_dir", | |
type=str, | |
default="./data", | |
help="Base directory for storing dataset" | |
) | |
parser.add_argument( | |
"--augment", | |
action="store_true", | |
help="Enable data augmentation" | |
) | |
parser.add_argument( | |
"--random_state", | |
type=int, | |
default=42, | |
help="Random seed for train-test split reproducibility" | |
) | |
parser.add_argument( | |
"--test_split_rate", | |
type=float, | |
default=0.2, | |
help="Proportion of data for validation (between 0 and 1)" | |
) | |
parser.add_argument( | |
"--rotation_range", | |
type=int, | |
default=15, | |
help="Maximum rotation angle in degrees for augmentation" | |
) | |
parser.add_argument( | |
"--source_subdir", | |
type=str, | |
default="Original Images/Original Images", | |
help="Subdirectory within raw_dir containing images" | |
) | |
parser.add_argument( | |
"--delete_raw", | |
action="store_true", | |
help="Delete the raw folder after processing to save storage" | |
) | |
args = parser.parse_args() | |
    download_and_split_kaggle_dataset(
        dataset_slug=args.dataset_slug,
        base_dir=args.base_dir,
        augment=args.augment,
        random_state=args.random_state,
        test_split_rate=args.test_split_rate,
        rotation_range=args.rotation_range,
        source_subdir=args.source_subdir,
        delete_raw=args.delete_raw
    )
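
# Example invocation (assuming this script is saved as download_dataset.py):
#   python download_dataset.py --dataset_slug vasukipatel/face-recognition-dataset \
#       --base_dir ./data --augment --delete_raw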