File size: 4,608 Bytes
8e5d8c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import cv2
import shutil
import imageio
import numpy as np
from glob import glob
from pathlib import Path
from typing import List, Tuple, Optional

def validate_dimensions(width: int, height: int, stride: int = 32) -> Tuple[int, int]:
    """Round image dimensions up to the nearest multiple of ``stride``.

    Args:
        width: Requested width in pixels.
        height: Requested height in pixels.
        stride: Positive integer both dimensions must be divisible by.

    Returns:
        A ``(width, height)`` tuple in which each value is the smallest
        multiple of ``stride`` that is greater than or equal to the input.
        Values that are already multiples are returned unchanged.

    Raises:
        ValueError: If ``stride`` is not a positive integer.
    """
    if stride <= 0:
        raise ValueError(f"stride must be a positive integer, got {stride}")

    def _round_up(value: int) -> int:
        # Ceiling division using floor division on the negated value.
        return -(-value // stride) * stride

    new_width = _round_up(width)
    new_height = _round_up(height)

    if (new_width, new_height) != (width, height):
        print(f'Adjusted dimensions to: {new_height}H x {new_width}W')

    # Return the adjusted values; returning the inputs would silently
    # discard the rounding that was just computed and reported.
    return new_width, new_height

def calc_image_size(image: np.ndarray, target_size: int) -> Tuple[int, int]:
    """Compute a resize target whose longer side equals ``target_size``.

    The aspect ratio (width / height) is taken from the first two entries
    of ``image.shape``. The shorter side is scaled by that ratio and
    truncated to an integer, and the resulting pair is passed through
    ``validate_dimensions``.

    Args:
        image: Array whose first two shape entries are read as height, width.
        target_size: Length assigned to the longer side.

    Returns:
        Whatever ``validate_dimensions(width, height)`` returns.
    """
    rows, cols = image.shape[:2]
    ratio = cols / rows
    is_landscape = ratio >= 1

    # The longer side is pinned to target_size; the other side follows the ratio.
    scaled_width = target_size if is_landscape else int(target_size * ratio)
    scaled_height = int(target_size / ratio) if is_landscape else target_size

    return validate_dimensions(scaled_width, scaled_height)

def convert_coordinates(transform: np.ndarray, x: float, y: float) -> Tuple[float, float]:
    """Apply ``transform`` to the point ``(x, y)`` in homogeneous form.

    The point is extended to ``[x, y, 1]`` and left-multiplied by
    ``transform``; the first two components of the product are returned.

    Args:
        transform: Matrix that can be multiplied with a 3-element vector.
        x: First coordinate of the point.
        y: Second coordinate of the point.

    Returns:
        The first and second components of ``transform @ [x, y, 1]``.
    """
    homogeneous_point = np.array([x, y, 1])
    mapped = transform @ homogeneous_point
    out_x = mapped[0]
    out_y = mapped[1]
    return out_x, out_y

def list_images(directory: str, mask_format: bool = False) -> List[str]:
    """Return the sorted, de-duplicated image paths directly inside ``directory``.

    Args:
        directory: Folder to search (not recursive).
        mask_format: When true, only ``*.png`` / ``*.PNG`` files are matched;
            otherwise JPEG, PNG and TIFF patterns in lower and upper case
            are matched.

    Returns:
        Matching paths (``directory`` joined with the file name), sorted.
    """
    if mask_format:
        patterns = ('*.png', '*.PNG')
    else:
        patterns = (
            '*.jpg', '*.jpeg', '*.png', '*.tif', '*.tiff',
            '*.JPG', '*.JPEG', '*.PNG', '*.TIF', '*.TIFF',
        )

    # A set absorbs duplicates when two patterns match the same file.
    matches = set()
    for pattern in patterns:
        matches.update(glob(os.path.join(directory, pattern)))

    return sorted(matches)

def prepare_dataset_split(root_dir: str, train_ratio: float = 0.7, 
                         generate_empty_masks: bool = False) -> None:
    """Randomly split ``root_dir/Images`` and ``root_dir/Masks`` into train/val.

    Files are copied (not moved) into ``root_dir/train/{Images,Masks}`` and,
    when ``train_ratio < 1``, ``root_dir/val/{Images,Masks}``. Images and
    masks are paired by their position in the sorted listings returned by
    ``list_images``; file names are not compared.

    Args:
        root_dir: Directory containing the ``Images`` and ``Masks`` folders.
        train_ratio: Fraction of samples assigned to the train split; must
            satisfy ``0 < train_ratio <= 1``.
        generate_empty_masks: When true, an empty mask is generated for every
            image via ``create_empty_masks`` and moved into ``Masks`` unless a
            file with that name already exists there.

    Raises:
        ValueError: If ``train_ratio`` is out of range, or if the number of
            images and masks differs.
        FileNotFoundError: If ``Images`` or ``Masks`` is missing.
    """
    # Validate the ratio first so a bad argument never leaves the filesystem
    # half-modified (masks generated and moved, but no split produced).
    train_ratio = float(train_ratio)
    if not (0 < train_ratio <= 1):
        raise ValueError(f"Invalid train ratio: {train_ratio}")

    image_dir = os.path.join(root_dir, 'Images')
    mask_dir = os.path.join(root_dir, 'Masks')

    if not all(os.path.exists(d) for d in [image_dir, mask_dir]):
        raise FileNotFoundError("Required 'Images' and 'Masks' directories not found")

    image_paths = list_images(image_dir)

    if generate_empty_masks:
        temp_dir = os.path.join(mask_dir, 'temp')
        try:
            create_empty_masks(image_dir, outdir=temp_dir)

            for mask_path in list_images(temp_dir, mask_format=True):
                target_path = os.path.join(mask_dir, os.path.basename(mask_path))
                # Never overwrite a real mask with a generated empty one.
                if not os.path.exists(target_path):
                    shutil.move(mask_path, target_path)
        finally:
            # Always remove the scratch directory, even if generation or a
            # move failed part-way through.
            shutil.rmtree(temp_dir, ignore_errors=True)

    # Listed after optional generation so newly moved masks are included.
    mask_paths = list_images(mask_dir, mask_format=True)

    if len(image_paths) != len(mask_paths):
        raise ValueError(f"Unmatched images ({len(image_paths)}) and masks ({len(mask_paths)})")

    train_size = int(np.floor(train_ratio * len(image_paths)))
    indices = np.random.permutation(len(image_paths))

    splits = {'train': indices[:train_size]}
    if train_ratio < 1:
        splits['val'] = indices[train_size:]

    sources_by_subdir = {'Images': image_paths, 'Masks': mask_paths}

    for split_name, split_indices in splits.items():
        split_dir = os.path.join(root_dir, split_name)
        for subdir, sources in sources_by_subdir.items():
            subdir_path = os.path.join(split_dir, subdir)
            os.makedirs(subdir_path, exist_ok=True)

            for idx in split_indices:
                source = sources[idx]
                destination = os.path.join(subdir_path, os.path.basename(source))
                shutil.copyfile(source, destination)

        print(f"Created {split_name} split with {len(split_indices)} samples")

def create_empty_masks(image_dir: str, pixel_value: int = 0, 
                      outdir: Optional[str] = None) -> str:
    """Write a constant-valued single-channel PNG mask for each image in ``image_dir``.

    Each image found by ``list_images`` is read to obtain its first two
    shape entries, and a ``uint8`` array of that size filled with
    ``pixel_value`` is written to ``<outdir>/<image stem>.png``.

    Args:
        image_dir: Directory whose images are enumerated.
        pixel_value: Value every mask pixel is set to.
        outdir: Destination directory; when falsy, ``<image_dir>/Masks`` is
            used. The directory is created if needed.

    Returns:
        The directory the masks were written to.
    """
    if not outdir:
        outdir = os.path.join(image_dir, 'Masks')
    os.makedirs(outdir, exist_ok=True)

    sources = list_images(image_dir)
    print(f"Generating {len(sources)} empty masks...")

    for source in sources:
        # The image is loaded only for its dimensions.
        loaded = imageio.imread(source)
        rows, cols = loaded.shape[0], loaded.shape[1]
        blank = np.full((rows, cols), pixel_value, dtype='uint8')

        mask_name = Path(source).stem + ".png"
        imageio.imwrite(os.path.join(outdir, mask_name), blank)

    return outdir