paint_defect_detector / src\prepare_data.py
therealestcoder's picture
Upload src\prepare_data.py with huggingface_hub
0d3fe75 verified
"""Подготовка датасета: нарезка панелей кузова на патчи и разбиение train/val.
Из исходных фото 4000x1846 (плоские панели — образцы окраски) автоматически
вырезается область панели (по яркости/контрасту), затем нарезаются перекрытые
патчи 512x512. Дефектные образцы → класс 1, образцы без дефектов → класс 0.
Запуск: python -m src.prepare_data
"""
from __future__ import annotations
import shutil
import random
from pathlib import Path
import cv2
import numpy as np
from . import config as C
def imread_unicode(path: Path) -> np.ndarray | None:
"""cv2.imread не понимает Cyrillic-пути на Windows — обходим через np.fromfile."""
try:
data = np.fromfile(str(path), dtype=np.uint8)
if data.size == 0:
return None
return cv2.imdecode(data, cv2.IMREAD_COLOR)
except Exception:
return None
def imwrite_unicode(path: Path, img: np.ndarray, params=None) -> bool:
ext = path.suffix or ".jpg"
ok, buf = cv2.imencode(ext, img, params or [])
if not ok:
return False
buf.tofile(str(path))
return True
def crop_panel(bgr: np.ndarray) -> np.ndarray:
"""Вырезает прямоугольник панели из светлого фона.
На исходных фото панель окраски лежит на белом столе. Бинаризуем изображение
по Оцу, берём наибольший контур и вырезаем его bounding box.
"""
gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
blur = cv2.GaussianBlur(gray, (9, 9), 0)
# Панель темнее белого фона -> инвертированный Оцу
_, th = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
th = cv2.morphologyEx(th, cv2.MORPH_OPEN, np.ones((15, 15), np.uint8))
contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if not contours:
return bgr
c = max(contours, key=cv2.contourArea)
x, y, w, h = cv2.boundingRect(c)
# отступаем внутрь, чтобы не зацепить край/тень/наклейку
pad = int(0.04 * min(w, h))
x1, y1 = max(0, x + pad), max(0, y + pad)
x2, y2 = min(bgr.shape[1], x + w - pad), min(bgr.shape[0], y + h - pad)
if x2 - x1 < 200 or y2 - y1 < 200:
return bgr
return bgr[y1:y2, x1:x2]
def slice_patches(panel: np.ndarray, size: int, stride: int) -> list[np.ndarray]:
"""Нарезает панель на квадратные патчи с заданным шагом."""
h, w = panel.shape[:2]
if h < size or w < size:
# маленькая панель: один центральный ресайз
return [cv2.resize(panel, (size, size))]
patches = []
ys = list(range(0, h - size + 1, stride))
xs = list(range(0, w - size + 1, stride))
if ys[-1] != h - size:
ys.append(h - size)
if xs[-1] != w - size:
xs.append(w - size)
for y in ys:
for x in xs:
patches.append(panel[y:y + size, x:x + size])
return patches
def main(val_ratio: float = 0.2, seed: int = C.SEED) -> None:
random.seed(seed)
src_pairs = [
(C.SRC_DEFECT, "defect"),
(C.SRC_CLEAN, "clean"),
]
# пересобираем выходные каталоги
if C.DATA_PATCHES.exists():
shutil.rmtree(C.DATA_PATCHES)
for split in ("train", "val"):
for cls in ("defect", "clean"):
(C.DATA_PATCHES / split / cls).mkdir(parents=True, exist_ok=True)
# также продублируем оригиналы в data/raw для удобства
C.DATA_RAW.mkdir(parents=True, exist_ok=True)
for src_dir, cls in src_pairs:
out = C.DATA_RAW / cls
out.mkdir(exist_ok=True)
for f in src_dir.iterdir():
if f.suffix.lower() in {".jpg", ".jpeg", ".png"}:
shutil.copy2(f, out / f.name)
total = {"train": 0, "val": 0}
for src_dir, cls in src_pairs:
files = [f for f in src_dir.iterdir()
if f.suffix.lower() in {".jpg", ".jpeg", ".png"}]
random.shuffle(files)
n_val = max(1, int(len(files) * val_ratio))
val_files = set(files[:n_val])
for f in files:
split = "val" if f in val_files else "train"
img = imread_unicode(f)
if img is None:
print(f"[skip] не удалось прочитать {f}")
continue
panel = crop_panel(img) if C.PANEL_CROP else img
patches = slice_patches(panel, C.PATCH_SIZE, C.PATCH_STRIDE)
stem = f.stem
for i, p in enumerate(patches):
out_path = C.DATA_PATCHES / split / cls / f"{stem}_{i:03d}.jpg"
imwrite_unicode(out_path, p, [cv2.IMWRITE_JPEG_QUALITY, 92])
total[split] += 1
print(f"[{split}/{cls}] {f.name}: {len(patches)} патчей")
print(f"\nИтого патчей: train={total['train']} val={total['val']}")
print(f"Готовый датасет: {C.DATA_PATCHES}")
if __name__ == "__main__":
main()