|
import argparse |
|
import glob |
|
import os |
|
import random |
|
|
|
from sklearn.model_selection import KFold |
|
|
|
|
|
def _collect_city_images(split_dir):
    """Return every ``*.png`` file located directly inside a sub-folder of ``split_dir``."""
    image_paths = []
    for city_folder in os.listdir(split_dir):
        city_path = os.path.join(split_dir, city_folder)
        if os.path.isdir(city_path):
            image_paths.extend(glob.glob(os.path.join(city_path, "*.png")))
    return image_paths


def _write_lines(file_path, entries):
    """Write ``entries`` to ``file_path``, one per line."""
    with open(file_path, "w", encoding="utf-8") as handle:
        handle.writelines(f"{entry}\n" for entry in entries)


def prepare_folds(cityscapes_path, output_dir, n_splits=3):
    """
    Prepares k-fold cross-validation splits for the Cityscapes dataset.

    The ``*.png`` files found in the city sub-folders of ``leftImg8bit/train``
    and ``leftImg8bit/val`` are pooled, converted to paths relative to
    ``leftImg8bit`` (with a trailing ``_leftImg8bit.png`` removed when
    present), sorted, shuffled with a fixed seed of 42 and split with
    ``KFold(n_splits=n_splits, shuffle=False)``. For fold ``k`` (1-based) the
    files ``fold_<k>_train_split.txt`` and ``fold_<k>_val_split.txt`` are
    written to ``output_dir``, one entry per line.

    Args:
        cityscapes_path (str): Path to the root Cityscapes directory
            (containing leftImg8bit and gtFine).
        output_dir (str): Directory to save the split files. Created if it
            does not exist.
        n_splits (int): Number of folds for cross-validation.

    Returns:
        None. A missing training directory, or finding no images at all, is
        reported on stdout and the function returns without writing anything.

    Raises:
        ValueError: If ``n_splits`` is smaller than 2 or larger than the
            number of images found. Raised before ``output_dir`` is created.
    """
    leftimg8bit_path = os.path.join(cityscapes_path, "leftImg8bit")
    train_img_dir = os.path.join(leftimg8bit_path, "train")
    val_img_dir = os.path.join(leftimg8bit_path, "val")

    # The training images are mandatory; without them there is nothing to do.
    if not os.path.isdir(train_img_dir):
        print(f"Error: Training image directory not found: {train_img_dir}")
        print(
            f"Please ensure '{cityscapes_path}' is the correct root and contains 'leftImg8bit/train'."
        )
        return
    train_files = _collect_city_images(train_img_dir)

    # A missing validation directory is reported but is not fatal: the folds
    # are then built from the training images alone.
    if os.path.isdir(val_img_dir):
        val_files = _collect_city_images(val_img_dir)
    else:
        print(f"Error: Validation image directory not found: {val_img_dir}")
        print(
            f"Please ensure '{cityscapes_path}' is the correct root and contains 'leftImg8bit/val'."
        )
        val_files = []

    if not train_files and not val_files:
        print(f"Error: No image files found in {train_img_dir} or {val_img_dir}.")
        print("Please check your Cityscapes dataset structure and path.")
        return

    # Store entries relative to leftImg8bit and without the image suffix.
    image_suffix = "_leftImg8bit.png"
    all_files_relative = []
    for image_path in train_files + val_files:
        rel_path = os.path.relpath(image_path, leftimg8bit_path)
        if rel_path.endswith(image_suffix):
            rel_path = rel_path[: -len(image_suffix)]
        all_files_relative.append(rel_path)

    # Validate the fold count before any side effect (directory creation,
    # file writes) so that a bad value leaves nothing half-done behind.
    if n_splits < 2:
        raise ValueError(f"n_splits must be at least 2, got {n_splits}.")
    if n_splits > len(all_files_relative):
        raise ValueError(
            f"Cannot create {n_splits} folds from only {len(all_files_relative)} image(s)."
        )

    # Sorting first makes the shuffle independent of directory listing order.
    # A private generator seeded with 42 produces the same permutation as
    # seeding the module-level one, without resetting the caller's RNG state.
    all_files_relative.sort()
    random.Random(42).shuffle(all_files_relative)

    # The list is already shuffled, so KFold itself must not shuffle again.
    kf = KFold(n_splits=n_splits, shuffle=False)

    os.makedirs(output_dir, exist_ok=True)

    for i, (train_index, val_index) in enumerate(kf.split(all_files_relative)):
        fold_train_files = [all_files_relative[k] for k in train_index]
        fold_val_files = [all_files_relative[k] for k in val_index]

        _write_lines(
            os.path.join(output_dir, f"fold_{i + 1}_train_split.txt"), fold_train_files
        )
        _write_lines(
            os.path.join(output_dir, f"fold_{i + 1}_val_split.txt"), fold_val_files
        )

        print(f"Fold {i + 1}: {len(fold_train_files)} train, {len(fold_val_files)} val")

        # Show a few entries of each list as a quick sanity check.
        print("Sample train files:")
        for sample in fold_train_files[:3]:
            print(f" {sample}")
        print("Sample val files:")
        for sample in fold_val_files[:3]:
            print(f" {sample}")

    print(f"Split files saved to: {os.path.abspath(output_dir)}")
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the arguments, validate them, build the
    # fold files and finally print a short reminder of how to invoke the script.
    parser = argparse.ArgumentParser(description="Prepare Cityscapes k-fold splits.")
    parser.add_argument(
        "cityscapes_path",
        type=str,
        help="Absolute path to the Cityscapes dataset directory.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Directory to save the split files. If not provided, a 'splits' folder will be created inside the cityscapes_path.",
    )
    parser.add_argument(
        "--n_splits", type=int, default=3, help="Number of folds (default: 3)."
    )

    args = parser.parse_args()

    # KFold needs at least two folds; reject smaller values as a usage error
    # instead of letting them surface later as a traceback.
    if args.n_splits < 2:
        parser.error(f"--n_splits must be at least 2 (got {args.n_splits}).")

    # Default location for the split files is <cityscapes_path>/splits.
    if args.output_dir is None:
        effective_output_dir = os.path.join(args.cityscapes_path, "splits")
    else:
        effective_output_dir = args.output_dir

    abs_cityscapes_path = os.path.abspath(args.cityscapes_path)
    if not os.path.isdir(abs_cityscapes_path):
        print(
            f"Error: Cityscapes path not found or is not a directory: {abs_cityscapes_path}"
        )
        # SystemExit is raised directly: the bare exit() helper is injected by
        # the site module and is not guaranteed to exist (e.g. with python -S).
        raise SystemExit(1)

    prepare_folds(abs_cityscapes_path, effective_output_dir, args.n_splits)

    # Usage reminder, echoing the --output_dir value when one was supplied.
    script_name = os.path.basename(__file__)
    print("\nTo run this script again, for example:")
    print(f"python {script_name} /path/to/your/cityscapes")
    if args.output_dir is not None:
        print(
            f"python {script_name} /path/to/your/cityscapes --output_dir {args.output_dir}"
        )
    print(
        "Replace '/path/to/your/cityscapes' with the actual path to your Cityscapes dataset."
    )
|
|