File size: 5,207 Bytes
e9629ef
d8aafaa
 
e9629ef
 
 
 
 
 
 
d8aafaa
 
5ce12fa
e9629ef
 
 
237de06
e9629ef
237de06
e9629ef
 
237de06
 
e9629ef
 
 
d8aafaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9629ef
 
d8aafaa
 
e9629ef
 
d8aafaa
 
e9629ef
 
d8aafaa
 
 
 
 
 
 
 
 
 
 
e9629ef
d8aafaa
 
 
e9629ef
 
d8aafaa
e9629ef
 
d8aafaa
e9629ef
d8aafaa
 
 
 
 
e9629ef
d8aafaa
e9629ef
 
d8aafaa
 
 
 
 
 
 
 
 
 
e9629ef
 
 
d8aafaa
 
 
 
 
 
 
 
 
 
 
 
e9629ef
 
237de06
e9629ef
 
237de06
 
 
 
 
 
 
 
d8aafaa
 
e9629ef
 
d8aafaa
 
e9629ef
 
d8aafaa
e9629ef
d8aafaa
 
 
e9629ef
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from os import path, listdir
from typing import Tuple, Union

import diskcache as dc
import hydra
import numpy as np
import torch
from loguru import logger
from PIL import Image
from torch.utils.data import Dataset
from tqdm.rich import tqdm

from drawer import draw_bboxes
from data_augment import Compose, RandomHorizontalFlip, RandomVerticalFlip, Mosaic, MixUp


class YoloDataset(Dataset):
    """Dataset of (image, bounding boxes) pairs read from a YOLO-style directory layout.

    For a phase directory ``<dataset_cfg.path>/<phase_name>``, every image in ``images/``
    is paired with the same-stem ``.txt`` file in ``labels/``. The resulting list of
    ``(image_path, bbox_tensor)`` pairs is cached with ``diskcache`` in
    ``<dataset_cfg.path>/.cache`` under the key ``phase_name``.
    """

    def __init__(self, dataset_cfg: dict, phase: str = "train", image_size: int = 640, transform=None):
        """
        Parameters:
            dataset_cfg (dict): Dataset config. Must expose ``.path`` (the dataset root) and ``.get``;
                presumably an OmegaConf ``DictConfig`` supplied by Hydra — TODO confirm.
            phase (str): Phase key (e.g. "train"). If ``dataset_cfg`` maps it to a name, that name is
                used as the sub-directory; otherwise the key itself is used.
            image_size (int): Image size stored on the dataset and forwarded to ``transform``.
            transform: Optional callable ``transform(img, bboxes) -> (img, bboxes)``. When provided it is
                given ``get_more_data`` and ``image_size`` attributes (presumably so multi-image
                augmentations such as Mosaic/MixUp can sample extra items — confirm in data_augment).
        """
        phase_name = dataset_cfg.get(phase, phase)
        self.image_size = image_size

        self.transform = transform
        # transform defaults to None; only attach the hooks to a real transform object,
        # otherwise constructing the dataset without augmentation raises AttributeError.
        if self.transform is not None:
            self.transform.get_more_data = self.get_more_data
            self.transform.image_size = self.image_size
        self.data = self.load_data(dataset_cfg.path, phase_name)

    def load_data(self, dataset_path, phase_name):
        """
        Loads data from a cache or generates a new cache for a specific dataset phase.

        Parameters:
            dataset_path (str): The root path to the dataset directory.
            phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.

        Returns:
            list: The ``(image_path, bbox_tensor)`` pairs for the specified phase.
        """
        cache_path = path.join(dataset_path, ".cache")
        # The context manager closes the cache even if scanning the dataset raises,
        # and all reads/writes happen while the cache is still open.
        with dc.Cache(cache_path) as cache:
            data = cache.get(phase_name)
            if data is None:
                logger.info("Generating {} cache", phase_name)
                images_path = path.join(dataset_path, phase_name, "images")
                labels_path = path.join(dataset_path, phase_name, "labels")
                data = self.filter_data(images_path, labels_path)
                cache[phase_name] = data

        logger.info("Loaded {} cache", phase_name)
        return data

    def filter_data(self, images_path: str, labels_path: str) -> list:
        """
        Filters and collects dataset information by pairing images with their corresponding labels.

        Only ``.jpg``/``.jpeg``/``.png`` files (case-insensitive) that have a same-stem ``.txt``
        label file containing at least one valid box are kept.

        Parameters:
            images_path (str): Path to the directory containing image files.
            labels_path (str): Path to the directory containing label files.

        Returns:
            list: A list of tuples, each containing the path to an image file and its associated labels as a tensor.
        """
        data = []
        images_list = sorted(listdir(images_path))
        for image_name in tqdm(images_list, desc="Filtering data"):
            if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
                continue

            img_path = path.join(images_path, image_name)
            base_name, _ = path.splitext(image_name)
            label_path = path.join(labels_path, f"{base_name}.txt")

            if not path.isfile(label_path):
                continue
            labels = self.load_valid_labels(label_path)
            if labels is not None:
                data.append((img_path, labels))

        # NOTE: the denominator counts every directory entry, including non-image files.
        logger.info("Recorded {}/{} valid inputs", len(data), len(images_list))
        return data

    @staticmethod
    def _parse_label_line(line: str) -> Union[torch.Tensor, None]:
        """
        Converts one label line into ``[cls, x_min, y_min, x_max, y_max]``.

        The line is read as ``cls x1 y1 x2 y2 ...``: a class id followed by (x, y) point pairs.
        NOTE(review): a plain ``cls cx cy w h`` box line would be read as two points — confirm
        the label format used by this project.

        Parameters:
            line (str): One raw line from a label file.

        Returns:
            torch.Tensor or None: The enclosing box of all points whose x AND y both lie in [0, 1];
            None for blank lines or when no point is in range.

        Raises:
            ValueError: If a token is not a number or the coordinates do not form whole (x, y) pairs.
        """
        values = [float(token) for token in line.split()]
        if not values:
            return None  # blank line
        cls, coords = values[0], values[1:]
        points = np.array(coords, dtype=float).reshape(-1, 2)
        # Keep or drop whole points: masking element-wise would flatten the array and
        # re-pair x/y coordinates from different points.
        in_range = ((points >= 0) & (points <= 1)).all(axis=1)
        valid_points = points[in_range]
        if len(valid_points) == 0:
            return None
        return torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])

    def load_valid_labels(self, label_path: str) -> Union[torch.Tensor, None]:
        """
        Loads bounding boxes from a label file, keeping only points inside [0, 1].

        Blank lines are ignored; malformed lines are skipped with a warning instead of
        aborting the whole dataset scan.

        Parameters:
            label_path (str): The filepath to the label file containing bounding box data.

        Returns:
            torch.Tensor or None: A tensor of all valid bounding boxes (one row per box) if any are found; otherwise, None.
        """
        bboxes = []
        with open(label_path, "r") as file:
            for line_no, line in enumerate(file, start=1):
                try:
                    bbox = self._parse_label_line(line)
                except ValueError:
                    logger.warning("Skipping malformed line {} in {}", line_no, label_path)
                    continue
                if bbox is not None:
                    bboxes.append(bbox)

        if bboxes:
            return torch.stack(bboxes)
        logger.warning("No valid BBox in {}", label_path)
        return None

    def get_data(self, idx):
        """Returns the RGB image and bbox tensor stored at ``idx``."""
        img_path, bboxes = self.data[idx]
        # convert() returns a fully loaded copy, so the underlying file can be closed right away.
        with Image.open(img_path) as img:
            return img.convert("RGB"), bboxes

    def get_more_data(self, num: int = 1):
        """Returns ``num`` ``(image, bboxes)`` pairs sampled uniformly at random (with replacement)."""
        indices = torch.randint(0, len(self), (num,)).tolist()
        return [self.get_data(idx) for idx in indices]

    def __getitem__(self, idx) -> Tuple["Image.Image", torch.Tensor]:
        """Returns ``(image, bboxes)`` for ``idx``, passed through ``transform`` when one is set."""
        img, bboxes = self.get_data(idx)
        if self.transform:
            img, bboxes = self.transform(img, bboxes)
        return img, bboxes

    def __len__(self) -> int:
        return len(self.data)


@hydra.main(config_path="../config", config_name="config", version_base=None)
def main(cfg):
    """Build the augmentation pipeline from ``cfg.augmentation``, then draw the boxes of sample 0.

    Each ``cfg.augmentation`` entry maps an augmentation class name to the probability
    passed to its constructor.
    """
    augmentations = []
    for aug_name, aug_prob in cfg.augmentation.items():
        # NOTE(review): eval() turns the config string into a class; this is only safe while
        # the config file is trusted. A name -> class registry would be stricter.
        augmentations.append(eval(aug_name)(aug_prob))
    pipeline = Compose(augmentations)

    dataset = YoloDataset(cfg.data, transform=pipeline)
    sample = dataset[0]
    draw_bboxes(*sample)


if __name__ == "__main__":
    import sys

    # Put the current working directory on the import path so the `tools.log_helper`
    # import below can resolve when this file is run as a script; it must precede that import.
    sys.path.append("./")
    from tools.log_helper import custom_logger

    # NOTE(review): custom_logger presumably configures loguru output — confirm in tools/log_helper.
    custom_logger()
    # No arguments here: the @hydra.main decorator on main() supplies cfg.
    main()