Spaces:
Sleeping
Sleeping
🔀 [Merge] branch 'DDP_BUGS' of github.com:WongKinYiu/yolov9mit into DDP_BUGS
Browse files- yolo/tools/data_loader.py +17 -14
- yolo/tools/solver.py +3 -4
yolo/tools/data_loader.py
CHANGED
|
@@ -23,6 +23,20 @@ from yolo.utils.dataset_utils import (
|
|
| 23 |
)
|
| 24 |
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
class YoloDataset(Dataset):
|
| 27 |
def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str = "train2017"):
|
| 28 |
augment_cfg = data_cfg.data_augment
|
|
@@ -32,19 +46,8 @@ class YoloDataset(Dataset):
|
|
| 32 |
transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
|
| 33 |
self.transform = AugmentationComposer(transforms, self.image_size)
|
| 34 |
self.transform.get_more_data = self.get_more_data
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def tensorlize(self, data):
|
| 38 |
-
img_paths, bboxes = zip(*data)
|
| 39 |
-
max_box = max(bbox.size(0) for bbox in bboxes)
|
| 40 |
-
padded_bbox_list = []
|
| 41 |
-
for bbox in bboxes:
|
| 42 |
-
padding = torch.full((max_box, 5), -1, dtype=torch.float32)
|
| 43 |
-
padding[: bbox.size(0)] = bbox
|
| 44 |
-
padded_bbox_list.append(padding)
|
| 45 |
-
bboxes = torch.stack(padded_bbox_list)
|
| 46 |
-
img_paths = np.array(img_paths)
|
| 47 |
-
return img_paths, bboxes
|
| 48 |
|
| 49 |
def load_data(self, dataset_path: Path, phase_name: str):
|
| 50 |
"""
|
|
@@ -147,7 +150,7 @@ class YoloDataset(Dataset):
|
|
| 147 |
img_path, bboxes = self.img_paths[idx], self.bboxes[idx]
|
| 148 |
valid_mask = bboxes[:, 0] != -1
|
| 149 |
img = Image.open(img_path).convert("RGB")
|
| 150 |
-
return img, bboxes[valid_mask], img_path
|
| 151 |
|
| 152 |
def get_more_data(self, num: int = 1):
|
| 153 |
indices = torch.randint(0, len(self), (num,))
|
|
|
|
| 23 |
)
|
| 24 |
|
| 25 |
|
| 26 |
+
def tensorlize(data):
|
| 27 |
+
# TODO Move Tensorlize to helper
|
| 28 |
+
img_paths, bboxes = zip(*data)
|
| 29 |
+
max_box = max(bbox.size(0) for bbox in bboxes)
|
| 30 |
+
padded_bbox_list = []
|
| 31 |
+
for bbox in bboxes:
|
| 32 |
+
padding = torch.full((max_box, 5), -1, dtype=torch.float32)
|
| 33 |
+
padding[: bbox.size(0)] = bbox
|
| 34 |
+
padded_bbox_list.append(padding)
|
| 35 |
+
bboxes = np.stack(padded_bbox_list)
|
| 36 |
+
img_paths = np.array(img_paths)
|
| 37 |
+
return img_paths, bboxes
|
| 38 |
+
|
| 39 |
+
|
| 40 |
class YoloDataset(Dataset):
|
| 41 |
def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str = "train2017"):
|
| 42 |
augment_cfg = data_cfg.data_augment
|
|
|
|
| 46 |
transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
|
| 47 |
self.transform = AugmentationComposer(transforms, self.image_size)
|
| 48 |
self.transform.get_more_data = self.get_more_data
|
| 49 |
+
img_paths, bboxes = tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))
|
| 50 |
+
self.img_paths, self.bboxes = img_paths, bboxes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
def load_data(self, dataset_path: Path, phase_name: str):
|
| 53 |
"""
|
|
|
|
| 150 |
img_path, bboxes = self.img_paths[idx], self.bboxes[idx]
|
| 151 |
valid_mask = bboxes[:, 0] != -1
|
| 152 |
img = Image.open(img_path).convert("RGB")
|
| 153 |
+
return img, torch.from_numpy(bboxes[valid_mask]), img_path
|
| 154 |
|
| 155 |
def get_more_data(self, num: int = 1):
|
| 156 |
indices = torch.randint(0, len(self), (num,))
|
yolo/tools/solver.py
CHANGED
|
@@ -147,7 +147,7 @@ class ModelTrainer:
|
|
| 147 |
self.progress.finish_one_epoch(epoch_loss, epoch_idx=epoch_idx)
|
| 148 |
|
| 149 |
mAPs = self.validator.solve(self.validation_dataloader, epoch_idx=epoch_idx)
|
| 150 |
-
if
|
| 151 |
self.save_checkpoint(epoch_idx=epoch_idx)
|
| 152 |
# TODO: save model if result are better than before
|
| 153 |
self.progress.finish_train()
|
|
@@ -256,9 +256,8 @@ class ModelValidator:
|
|
| 256 |
|
| 257 |
with open(self.json_path, "w") as f:
|
| 258 |
predict_json = collect_prediction(predict_json, self.progress.local_rank)
|
| 259 |
-
if self.progress.local_rank
|
| 260 |
-
|
| 261 |
-
json.dump(predict_json, f)
|
| 262 |
if hasattr(self, "coco_gt"):
|
| 263 |
self.progress.start_pycocotools()
|
| 264 |
result = calculate_ap(self.coco_gt, predict_json)
|
|
|
|
| 147 |
self.progress.finish_one_epoch(epoch_loss, epoch_idx=epoch_idx)
|
| 148 |
|
| 149 |
mAPs = self.validator.solve(self.validation_dataloader, epoch_idx=epoch_idx)
|
| 150 |
+
if self.good_epoch(mAPs):
|
| 151 |
self.save_checkpoint(epoch_idx=epoch_idx)
|
| 152 |
# TODO: save model if result are better than before
|
| 153 |
self.progress.finish_train()
|
|
|
|
| 256 |
|
| 257 |
with open(self.json_path, "w") as f:
|
| 258 |
predict_json = collect_prediction(predict_json, self.progress.local_rank)
|
| 259 |
+
if self.progress.local_rank == 0:
|
| 260 |
+
json.dump(predict_json, f)
|
|
|
|
| 261 |
if hasattr(self, "coco_gt"):
|
| 262 |
self.progress.start_pycocotools()
|
| 263 |
result = calculate_ap(self.coco_gt, predict_json)
|