# File: segmentation/prepare_cityscapes_folds.py
# Commit 63e060d (verified) by Leonardo6: "Add files using upload-large-folder tool".
import argparse
import glob
import os
import random
from sklearn.model_selection import KFold
def prepare_folds(cityscapes_path, output_dir, n_splits=3, seed=42):
    """Prepare k-fold cross-validation splits for the Cityscapes dataset.

    All ``*.png`` images found in the city sub-directories of
    ``leftImg8bit/train`` and ``leftImg8bit/val`` are pooled and converted to
    sample ids relative to ``leftImg8bit``. For example,
    ``train/aachen/aachen_000000_000019``: the ``_leftImg8bit.png`` suffix is
    removed and ``/`` is always used as the separator. The ids are sorted,
    shuffled deterministically and split into ``n_splits`` folds.

    For every fold ``i`` (1-based), two files are written to ``output_dir``:
    ``fold_{i}_train_split.txt`` and ``fold_{i}_val_split.txt``. Each file
    contains one sample id per line.

    Args:
        cityscapes_path (str): Path to the root Cityscapes directory
            (containing leftImg8bit and gtFine).
        output_dir (str): Directory to save the split files. It is created if
            missing.
        n_splits (int): Number of folds for cross-validation.
        seed (int): Seed for the shuffle applied before splitting. The
            default (42) reproduces the splits of earlier versions of this
            script.

    Returns:
        None. If ``leftImg8bit/train`` is missing, or no images are found,
        the problem is printed and the function returns before writing
        anything.

    Raises:
        ValueError: Raised by scikit-learn's ``KFold`` when ``n_splits`` is
            less than 2 or greater than the number of images found.
    """
    leftimg8bit_path = os.path.join(cityscapes_path, "leftImg8bit")
    train_img_dir = os.path.join(leftimg8bit_path, "train")
    val_img_dir = os.path.join(leftimg8bit_path, "val")

    # The training split is mandatory; abort if its directory is missing.
    if not os.path.isdir(train_img_dir):
        print(f"Error: Training image directory not found: {train_img_dir}")
        print(
            f"Please ensure '{cityscapes_path}' is the correct root and contains 'leftImg8bit/train'."
        )
        return
    train_files = _collect_city_images(train_img_dir)

    # The validation split is optional: warn and continue with training images.
    if os.path.isdir(val_img_dir):
        val_files = _collect_city_images(val_img_dir)
    else:
        print(f"Warning: Validation image directory not found: {val_img_dir}")
        print(
            f"Please ensure '{cityscapes_path}' is the correct root and contains 'leftImg8bit/val'."
        )
        val_files = []

    if not train_files and not val_files:
        print(f"Error: No image files found in {train_img_dir} or {val_img_dir}.")
        print("Please check your Cityscapes dataset structure and path.")
        return

    # Sort first so that the shuffle does not depend on os.listdir/glob order.
    sample_ids = sorted(
        _to_sample_id(path, leftimg8bit_path) for path in train_files + val_files
    )
    # A private generator produces the same permutation as
    # random.seed(seed) + random.shuffle(), without resetting the global
    # random state for the rest of the process.
    random.Random(seed).shuffle(sample_ids)

    kf = KFold(n_splits=n_splits, shuffle=False)  # Shuffling is already done.
    os.makedirs(output_dir, exist_ok=True)
    for fold, (train_index, val_index) in enumerate(kf.split(sample_ids), start=1):
        fold_train_files = [sample_ids[k] for k in train_index]
        fold_val_files = [sample_ids[k] for k in val_index]
        _write_lines(
            os.path.join(output_dir, f"fold_{fold}_train_split.txt"), fold_train_files
        )
        _write_lines(
            os.path.join(output_dir, f"fold_{fold}_val_split.txt"), fold_val_files
        )
        print(f"Fold {fold}: {len(fold_train_files)} train, {len(fold_val_files)} val")
        # Debug output: show a few ids so the path format can be checked by eye.
        print("Sample train files:")
        for sample in fold_train_files[:3]:
            print(f" {sample}")
        print("Sample val files:")
        for sample in fold_val_files[:3]:
            print(f" {sample}")
    print(f"Split files saved to: {os.path.abspath(output_dir)}")


def _collect_city_images(split_dir):
    """Return the ``*.png`` paths found in each city sub-directory of *split_dir*."""
    images = []
    for city_folder in os.listdir(split_dir):
        city_path = os.path.join(split_dir, city_folder)
        if os.path.isdir(city_path):
            images.extend(glob.glob(os.path.join(city_path, "*.png")))
    return images


def _to_sample_id(image_path, leftimg8bit_path):
    """Return *image_path* relative to *leftimg8bit_path* as a split-file id.

    The ``_leftImg8bit.png`` suffix is stripped when present. Separators are
    normalized to ``/``, so the split files are identical on every OS.
    """
    suffix = "_leftImg8bit.png"
    rel_path = os.path.relpath(image_path, leftimg8bit_path).replace(os.sep, "/")
    if rel_path.endswith(suffix):
        rel_path = rel_path[: -len(suffix)]
    return rel_path


def _write_lines(path, items):
    """Write *items* to *path*, one per line, using Unix line endings."""
    # newline="\n" stops text mode from translating "\n" into "\r\n" on Windows.
    with open(path, "w", encoding="utf-8", newline="\n") as f:
        f.writelines(f"{item}\n" for item in items)


def main(argv=None):
    """Command-line entry point: parse arguments and write the fold files.

    Args:
        argv (list[str] | None): Arguments to parse. ``None`` lets argparse
            use ``sys.argv[1:]``.

    Exits with status 1 if the Cityscapes path is not a directory. Exits
    through argparse's usage error (status 2) if ``--n_splits`` is below 2.
    """
    import sys

    parser = argparse.ArgumentParser(description="Prepare Cityscapes k-fold splits.")
    parser.add_argument(
        "cityscapes_path",
        type=str,
        help="Absolute path to the Cityscapes dataset directory.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Directory to save the split files. If not provided, a 'splits' folder will be created inside the cityscapes_path.",
    )
    parser.add_argument(
        "--n_splits", type=int, default=3, help="Number of folds (default: 3)."
    )
    args = parser.parse_args(argv)
    # k-fold needs at least two folds. Report a usage error here instead of
    # letting scikit-learn raise a traceback later.
    if args.n_splits < 2:
        parser.error(f"--n_splits must be at least 2 (got {args.n_splits}).")

    if args.output_dir is None:
        effective_output_dir = os.path.join(args.cityscapes_path, "splits")
    else:
        effective_output_dir = args.output_dir
    abs_cityscapes_path = os.path.abspath(args.cityscapes_path)
    if not os.path.isdir(abs_cityscapes_path):
        print(
            f"Error: Cityscapes path not found or is not a directory: {abs_cityscapes_path}"
        )
        # sys.exit is always available; the bare builtin exit() comes from the
        # site module and is absent under `python -S` or in frozen builds.
        sys.exit(1)

    prepare_folds(abs_cityscapes_path, effective_output_dir, args.n_splits)

    script_name = os.path.basename(__file__)
    print("\nTo run this script again, for example:")
    print(f"python {script_name} /path/to/your/cityscapes")
    if args.output_dir is not None:
        print(
            f"python {script_name} /path/to/your/cityscapes --output_dir {args.output_dir}"
        )
    print(
        "Replace '/path/to/your/cityscapes' with the actual path to your Cityscapes dataset."
    )


if __name__ == "__main__":
    main()