File size: 5,207 Bytes
e9629ef d8aafaa e9629ef d8aafaa 5ce12fa e9629ef 237de06 e9629ef 237de06 e9629ef 237de06 e9629ef d8aafaa e9629ef d8aafaa e9629ef d8aafaa e9629ef d8aafaa e9629ef d8aafaa e9629ef d8aafaa e9629ef d8aafaa e9629ef d8aafaa e9629ef d8aafaa e9629ef d8aafaa e9629ef d8aafaa e9629ef 237de06 e9629ef 237de06 d8aafaa e9629ef d8aafaa e9629ef d8aafaa e9629ef d8aafaa e9629ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
from os import path, listdir
from typing import Tuple, Union

import diskcache as dc
import hydra
import numpy as np
import torch
from loguru import logger
from PIL import Image
from torch.utils.data import Dataset
from tqdm.rich import tqdm

from drawer import draw_bboxes
from data_augment import Compose, RandomHorizontalFlip, RandomVerticalFlip, Mosaic, MixUp
class YoloDataset(Dataset):
    """
    Map-style dataset that pairs image files with per-image label files.

    Samples are ``(image, bboxes)`` tuples. ``bboxes`` is a tensor with one row per
    object, laid out as ``[class_id, min_0, min_1, max_0, max_1]`` over the object's
    in-range coordinate pairs (presumably ``(x, y)`` pairs, giving an x/y box — see
    ``load_valid_labels``). Parsed records are cached with diskcache under
    ``<dataset_path>/.cache`` so later runs do not re-read every label file.
    """

    def __init__(self, dataset_cfg: dict, phase: str = "train", image_size: int = 640, transform=None):
        """
        Parameters:
            dataset_cfg: Dataset configuration. It is read via ``.get(phase, phase)`` and
                ``.path``, so it must support attribute access (presumably an OmegaConf
                node; the ``dict`` hint is loose).
            phase: Phase key; ``dataset_cfg.get(phase, phase)`` maps it to the on-disk
                sub-directory name, falling back to the key itself.
            image_size: Image size stored on the dataset and forwarded to the transform.
            transform: Optional callable ``(img, bboxes) -> (img, bboxes)``; may be None.
        """
        phase_name = dataset_cfg.get(phase, phase)
        self.image_size = image_size
        self.transform = transform
        # Only wire the hooks when a transform exists: the default transform=None
        # previously crashed here with AttributeError.
        if self.transform:
            # Gives the transform a way to sample extra (img, bboxes) pairs —
            # presumably for multi-image augmentations such as Mosaic/MixUp.
            self.transform.get_more_data = self.get_more_data
            self.transform.image_size = self.image_size
        self.data = self.load_data(dataset_cfg.path, phase_name)

    def load_data(self, dataset_path, phase_name):
        """
        Loads data from a cache or generates a new cache for a specific dataset phase.

        Parameters:
            dataset_path (str): The root path to the dataset directory.
            phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.

        Returns:
            list: The ``(image_path, labels_tensor)`` records for the requested phase.
        """
        cache_path = path.join(dataset_path, ".cache")
        # The context manager guarantees the cache is closed on every path,
        # including the cache-hit path that previously leaked the connection.
        with dc.Cache(cache_path) as cache:
            data = cache.get(phase_name)
            if data is None:
                logger.info("Generating {} cache", phase_name)
                images_path = path.join(dataset_path, phase_name, "images")
                labels_path = path.join(dataset_path, phase_name, "labels")
                data = self.filter_data(images_path, labels_path)
                cache[phase_name] = data
        logger.info("Loaded {} cache", phase_name)
        return data

    def filter_data(self, images_path: str, labels_path: str) -> list:
        """
        Filters and collects dataset information by pairing images with their corresponding labels.

        An image is kept only if ``<labels_path>/<image stem>.txt`` exists and yields at
        least one valid box.

        Parameters:
            images_path (str): Path to the directory containing image files.
            labels_path (str): Path to the directory containing label files.

        Returns:
            list: A list of tuples, each containing the path to an image file and its associated labels as a tensor.
        """
        image_names = [
            name for name in sorted(listdir(images_path))
            if name.lower().endswith((".jpg", ".jpeg", ".png"))
        ]
        data = []
        for image_name in tqdm(image_names, desc="Filtering data"):
            base_name, _ = path.splitext(image_name)
            label_path = path.join(labels_path, f"{base_name}.txt")
            if not path.isfile(label_path):
                continue
            labels = self.load_valid_labels(label_path)
            if labels is not None:
                data.append((path.join(images_path, image_name), labels))
        # Denominator counts image files only, not every directory entry.
        logger.info("Recorded {}/{} valid inputs", len(data), len(image_names))
        return data

    def load_valid_labels(self, label_path: str) -> Union[torch.Tensor, None]:
        """
        Loads a label file and turns each object into a box around its in-range points.

        Each non-blank line is parsed as ``class_id c1 c2 c3 c4 ...``: a class id followed
        by coordinate pairs. A pair is kept only if BOTH of its values lie in [0, 1]; the
        kept pairs are reduced to ``[class_id, min_0, min_1, max_0, max_1]``. Lines with an
        odd number of coordinates are skipped with a warning.
        NOTE(review): a detection-style line ``cls cx cy w h`` would be read as two points
        — confirm the label format is point/polygon based.

        Parameters:
            label_path (str): The filepath to the label file containing bounding box data.

        Returns:
            torch.Tensor or None: A (num_boxes, 5) tensor of all valid boxes if any are found; otherwise, None.
        """
        bboxes = []
        with open(label_path, "r", encoding="utf-8") as file:
            for line in file:
                tokens = line.split()
                if not tokens:
                    continue  # blank line, e.g. a trailing newline
                values = list(map(float, tokens))
                cls, coords = values[0], values[1:]
                if len(coords) % 2:
                    logger.warning("Skipping line with an odd number of coordinates in {}", label_path)
                    continue
                points = np.array(coords).reshape(-1, 2)
                # Mask whole points (rows). An element-wise mask flattens the array and
                # pairs up coordinates from different points.
                in_range = np.all((points >= 0) & (points <= 1), axis=1)
                valid_points = points[in_range]
                if valid_points.size > 1:
                    bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])
                    bboxes.append(bbox)
        if bboxes:
            return torch.stack(bboxes)
        logger.warning("No valid BBox in {}", label_path)
        return None

    def get_data(self, idx):
        """Return the RGB image and label tensor stored at position ``idx``."""
        img_path, bboxes = self.data[idx]
        # Close the underlying file; convert() returns an independent image.
        with Image.open(img_path) as source:
            img = source.convert("RGB")
        return img, bboxes

    def get_more_data(self, num: int = 1):
        """Return ``num`` (image, bboxes) pairs drawn uniformly at random, with replacement."""
        indices = torch.randint(0, len(self), (num,)).tolist()
        return [self.get_data(idx) for idx in indices]

    def __getitem__(self, idx) -> "Tuple[Union[Image.Image, torch.Tensor], torch.Tensor]":
        """Return the (optionally transformed) ``(image, bboxes)`` pair at ``idx``."""
        img, bboxes = self.get_data(idx)
        if self.transform:
            img, bboxes = self.transform(img, bboxes)
        return img, bboxes

    def __len__(self) -> int:
        """Number of image/label pairs that passed filtering."""
        return len(self.data)
@hydra.main(config_path="../config", config_name="config", version_base=None)
def main(cfg):
    """
    Build the dataset described by the Hydra config and draw the boxes of its first sample.

    ``cfg.augmentation`` maps augmentation class names to the single argument each one is
    constructed with (``prob``); ``cfg.data`` is passed to ``YoloDataset`` as its config.
    """
    # Resolve names through an explicit allow-list rather than eval(), so a config
    # entry can only select one of the imported augmentation classes and can never
    # execute arbitrary code.
    available_augmentations = {
        "RandomHorizontalFlip": RandomHorizontalFlip,
        "RandomVerticalFlip": RandomVerticalFlip,
        "Mosaic": Mosaic,
        "MixUp": MixUp,
    }
    augmentations = []
    for aug_name, prob in cfg.augmentation.items():
        if aug_name not in available_augmentations:
            raise ValueError(f"Unknown augmentation '{aug_name}' in config")
        augmentations.append(available_augmentations[aug_name](prob))
    transform = Compose(augmentations)
    dataset = YoloDataset(cfg.data, transform=transform)
    draw_bboxes(*dataset[0])
if __name__ == "__main__":
    import sys

    # Append the current working directory to the import path so the project-level
    # ``tools`` package can be found (presumably the script is run from the repo root).
    sys.path.append("./")
    from tools.log_helper import custom_logger as setup_custom_logger

    setup_custom_logger()
    main()
|