File size: 1,964 Bytes
cb80c28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import shutil
import sys
from sklearn.model_selection import train_test_split


def split_data(data_dir, train_ratio=0.8, seed=42):
    """Split a class-per-folder dataset into ``train/`` and ``val/`` in place.

    Every immediate subdirectory of *data_dir* other than ``train`` and
    ``val`` is treated as one class. Its regular files are divided with
    ``sklearn.model_selection.train_test_split`` and **moved** (not copied)
    into ``data_dir/train/<class>/`` and ``data_dir/val/<class>/``. The
    emptied source class directories are left in place.

    Classes with fewer than two files cannot be split. Their files (if any)
    all go to ``train/``, and a warning is written to stderr. Both per-class
    destination directories are still created, so ``train/`` and ``val/``
    list the same set of classes.

    Args:
        data_dir: Root directory containing one subdirectory per class.
        train_ratio: Passed to ``train_test_split`` as ``train_size``. A float
            in (0, 1) is the fraction of each class's files sent to ``train/``.
        seed: ``random_state`` for ``train_test_split``. File lists are
            sorted first, so the same file set always yields the same split.

    Raises:
        ValueError: From ``train_test_split`` if *train_ratio* is invalid, or
            if it would leave a class of 2+ files with an empty train or val
            side.
        OSError: If a directory cannot be created or a file cannot be moved.
    """
    train_dir = os.path.join(data_dir, "train")
    val_dir = os.path.join(data_dir, "val")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    # Sorted so classes (and any warnings) are processed in a stable order.
    for class_name in sorted(os.listdir(data_dir)):
        class_path = os.path.join(data_dir, class_name)
        if class_name in ("train", "val") or not os.path.isdir(class_path):
            continue

        # train_test_split shuffles by position, so the input order must be
        # stable for `seed` to give a reproducible split. os.listdir order is
        # arbitrary, hence the sort.
        files = sorted(
            name
            for name in os.listdir(class_path)
            if os.path.isfile(os.path.join(class_path, name))
        )

        train_class_dir = os.path.join(train_dir, class_name)
        val_class_dir = os.path.join(val_dir, class_name)
        os.makedirs(train_class_dir, exist_ok=True)
        os.makedirs(val_class_dir, exist_ok=True)

        if len(files) < 2:
            # train_test_split raises ValueError for 0 or 1 samples, because
            # one side would be empty. Handle the case here instead of
            # aborting with some classes already moved.
            print(
                f"Warning: class '{class_name}' has {len(files)} file(s); "
                "too few to split, moving all to train/.",
                file=sys.stderr,
            )
            train_files, val_files = files, []
        else:
            train_files, val_files = train_test_split(
                files, train_size=train_ratio, random_state=seed
            )

        _move_files(train_files, class_path, train_class_dir)
        _move_files(val_files, class_path, val_class_dir)

    print("Data split complete.")


def _move_files(file_names, src_dir, dst_dir):
    """Move each named file from *src_dir* into *dst_dir*, keeping its name."""
    for name in file_names:
        shutil.move(os.path.join(src_dir, name), os.path.join(dst_dir, name))


if __name__ == "__main__":
    # The CLI takes exactly one positional argument: the dataset root.
    cli_args = sys.argv[1:]
    if len(cli_args) != 1:
        print("Usage: python split_data.py <data_dir>")
        sys.exit(1)

    (root_dir,) = cli_args
    split_data(root_dir)