"""
Improved PPE Compliance Detection Training Script v2
Fixed: Added config='full' for keremberke dataset
Combines multiple datasets for better coverage:
  1. 51ddhesh/PPE_Detection (~10K images, 6 PPE classes, YOLO format)
  2. keremberke/construction-safety-object-detection (398 images, 17 classes incl. violations)

Trains YOLOv8s on combined data.
"""
|
|
import os
import shutil
import sys
import zipfile
from pathlib import Path

import yaml
from datasets import load_dataset
from huggingface_hub import HfApi, hf_hub_download
from PIL import Image
|
|
| |
| HF_USERNAME = "baskarmother" |
| MODEL_ID = "yolov8s-ppe-construction-v2" |
| DATASET_DIR = Path("/app/combined_ppe_dataset") |
| EPOCHS = 150 |
| IMG_SIZE = 640 |
| BATCH = 16 |
| DEVICE = "0" |
|
|
| |
| UNIFIED_CLASSES = [ |
| "person", |
| "helmet", |
| "vest", |
| "mask", |
| "gloves", |
| "safety_shoe", |
| "goggles", |
| "no_helmet", |
| "no_mask", |
| "no_vest", |
| "head", |
| "barricade", |
| "dumpster", |
| "excavators", |
| "safety_net", |
| "dump_truck", |
| "truck", |
| "wheel_loader", |
| ] |
|
|
|
|
| def download_ppe_dataset(): |
| """Download 51ddhesh/PPE_Detection ZIP and extract.""" |
| print("[1/5] Downloading 51ddhesh/PPE_Detection dataset...") |
| zip_path = hf_hub_download( |
| repo_id="51ddhesh/PPE_Detection", |
| filename="PPE.zip", |
| repo_type="dataset", |
| cache_dir="/app/hf_cache", |
| local_dir="/app/downloads", |
| ) |
| extract_dir = Path("/app/downloads/ppe_dataset") |
| extract_dir.mkdir(parents=True, exist_ok=True) |
| with zipfile.ZipFile(zip_path, 'r') as zf: |
| zf.extractall(extract_dir) |
| print(f" Extracted to {extract_dir}") |
| return extract_dir |
|
|
|
|
| def load_keremberke_dataset(): |
| """Load keremberke construction-safety-object-detection.""" |
| print("[2/5] Loading keremberke/construction-safety-object-detection...") |
| ds = load_dataset("keremberke/construction-safety-object-detection", "full") |
| print(f" Splits: {list(ds.keys())}") |
| return ds |
|
|
|
|
| def convert_keremberke_to_yolo(ds, output_dir: Path): |
| """Convert keremberke COCO-style dataset to YOLO format.""" |
| print("[3/5] Converting keremberke dataset to YOLO format...") |
| class_names = ds["train"].features["objects"].feature["category"].names |
| print(f" Classes: {class_names}") |
|
|
| class_map = { |
| "person": 0, |
| "hardhat": 1, |
| "mask": 3, |
| "no-hardhat": 7, |
| "no-mask": 8, |
| "no-safety vest": 9, |
| "gloves": 4, |
| "safety shoes": 5, |
| "safety vest": 2, |
| "barricade": 11, |
| "dumpster": 12, |
| "excavators": 13, |
| "safety net": 14, |
| "dump truck": 15, |
| "mini-van": 0, |
| "truck": 16, |
| "wheel loader": 17, |
| } |
|
|
| for split in ["train", "valid", "test"]: |
| if split not in ds: |
| continue |
| images_dir = output_dir / split / "images" |
| labels_dir = output_dir / split / "labels" |
| images_dir.mkdir(parents=True, exist_ok=True) |
| labels_dir.mkdir(parents=True, exist_ok=True) |
|
|
| for i, example in enumerate(ds[split]): |
| img = example["image"] |
| img_filename = f"keremberke_{split}_{i:05d}.jpg" |
| img_path = images_dir / img_filename |
| img.save(img_path) |
|
|
| width, height = img.size |
| objects = example["objects"] |
| bboxes = objects["bbox"] |
| categories = objects["category"] |
|
|
| label_filename = img_filename.replace(".jpg", ".txt") |
| label_path = labels_dir / label_filename |
|
|
| with open(label_path, "w") as f: |
| for bbox, cat in zip(bboxes, categories): |
| class_name = class_names[cat] |
| if class_name not in class_map: |
| continue |
| unified_idx = class_map[class_name] |
|
|
| x, y, w, h = bbox |
| x_center = (x + w / 2) / width |
| y_center = (y + h / 2) / height |
| norm_w = w / width |
| norm_h = h / height |
|
|
| x_center = max(0, min(1, x_center)) |
| y_center = max(0, min(1, y_center)) |
| norm_w = max(0, min(1, norm_w)) |
| norm_h = max(0, min(1, norm_h)) |
|
|
| f.write(f"{unified_idx} {x_center:.6f} {y_center:.6f} {norm_w:.6f} {norm_h:.6f}\n") |
|
|
| print(f" Converted keremberke dataset to {output_dir}") |
|
|
|
|
| def merge_datasets(ppe_extract_dir: Path, keremberke_dir: Path, output_dir: Path): |
| """Merge both datasets into unified YOLO structure.""" |
| print("[4/5] Merging datasets...") |
| output_dir.mkdir(parents=True, exist_ok=True) |
|
|
| ppe_dir = None |
| for candidate in [ppe_extract_dir / "PPE", ppe_extract_dir / "ppe", ppe_extract_dir]: |
| if (candidate / "train" / "images").exists(): |
| ppe_dir = candidate |
| break |
|
|
| if ppe_dir is None: |
| print(" ERROR: Could not find PPE dataset structure") |
| print(f" Contents: {list(ppe_extract_dir.iterdir())}") |
| sys.exit(1) |
|
|
| print(f" Found PPE dataset at: {ppe_dir}") |
|
|
| ppe_class_map = { |
| 0: 2, |
| 1: 5, |
| 2: 3, |
| 3: 1, |
| 4: 6, |
| 5: 4, |
| } |
|
|
| for split in ["train", "valid", "test"]: |
| out_images = output_dir / split / "images" |
| out_labels = output_dir / split / "labels" |
| out_images.mkdir(parents=True, exist_ok=True) |
| out_labels.mkdir(parents=True, exist_ok=True) |
|
|
| ppe_images = ppe_dir / split / "images" |
| ppe_labels = ppe_dir / split / "labels" |
|
|
| if ppe_images.exists(): |
| for img_file in sorted(ppe_images.iterdir()): |
| if img_file.suffix.lower() not in [".jpg", ".jpeg", ".png"]: |
| continue |
| shutil.copy2(img_file, out_images / f"ppe_{img_file.name}") |
|
|
| label_file = ppe_labels / f"{img_file.stem}.txt" |
| if label_file.exists(): |
| with open(label_file) as f: |
| lines = f.readlines() |
| remapped = [] |
| for line in lines: |
| parts = line.strip().split() |
| if len(parts) < 5: |
| continue |
| src_cls = int(parts[0]) |
| if src_cls in ppe_class_map: |
| unified_cls = ppe_class_map[src_cls] |
| remapped.append(f"{unified_cls} {' '.join(parts[1:])}\n") |
|
|
| out_label = out_labels / f"ppe_{img_file.stem}.txt" |
| with open(out_label, "w") as f: |
| f.writelines(remapped) |
|
|
| k_images = keremberke_dir / split / "images" |
| k_labels = keremberke_dir / split / "labels" |
|
|
| if k_images.exists(): |
| for img_file in sorted(k_images.iterdir()): |
| shutil.copy2(img_file, out_images / img_file.name) |
| for label_file in sorted(k_labels.iterdir()): |
| shutil.copy2(label_file, out_labels / label_file.name) |
|
|
| data_yaml = { |
| "path": str(output_dir.absolute()), |
| "train": "train/images", |
| "val": "valid/images", |
| "test": "test/images", |
| "names": {i: name for i, name in enumerate(UNIFIED_CLASSES)}, |
| "nc": len(UNIFIED_CLASSES), |
| } |
|
|
| with open(output_dir / "data.yaml", "w") as f: |
| yaml.dump(data_yaml, f, default_flow_style=False) |
|
|
| print(f" Merged dataset at {output_dir}") |
| for split in ["train", "valid", "test"]: |
| img_count = len(list((output_dir / split / "images").glob("*"))) |
| print(f" {split}: {img_count} images") |
|
|
|
|
| def train_model(data_yaml_path: Path): |
| print("[5/5] Training YOLOv8s...") |
| from ultralytics import YOLO |
|
|
| model = YOLO("yolov8s.pt") |
|
|
| results = model.train( |
| data=str(data_yaml_path), |
| epochs=EPOCHS, |
| imgsz=IMG_SIZE, |
| batch=BATCH, |
| device=DEVICE, |
| patience=30, |
| project="/app/runs", |
| name="ppe_improved", |
| exist_ok=True, |
| pretrained=True, |
| optimizer="SGD", |
| lr0=0.01, |
| lrf=0.01, |
| momentum=0.9, |
| weight_decay=0.0005, |
| augment=True, |
| mosaic=1.0, |
| hsv_h=0.015, |
| hsv_s=0.7, |
| hsv_v=0.4, |
| degrees=5.0, |
| translate=0.1, |
| scale=0.5, |
| shear=2.0, |
| perspective=0.0, |
| flipud=0.0, |
| fliplr=0.5, |
| ) |
|
|
| print(" Training complete!") |
| print(f" Best model: {results.best}") |
| return results |
|
|
|
|
| def push_to_hub(best_model_path: Path): |
| print("Pushing model to HuggingFace Hub...") |
| api = HfApi() |
| repo_id = f"{HF_USERNAME}/{MODEL_ID}" |
|
|
| try: |
| api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True) |
| except Exception as e: |
| print(f" Repo creation info: {e}") |
|
|
| api.upload_file( |
| path_or_fileobj=str(best_model_path), |
| path_in_repo="best.pt", |
| repo_id=repo_id, |
| repo_type="model", |
| ) |
|
|
| readme = f"""--- |
| license: cc-by-4.0 |
| library_name: ultralytics |
| tags: |
| - object-detection |
| - ppe |
| - construction-safety |
| - yolov8 |
| --- |
| |
| # {MODEL_ID} |
| |
| Improved PPE Compliance Detection Model for Construction Sites (v2) |
| |
| ## Description |
| This is an improved YOLOv8s model trained on a combined dataset of: |
| - **51ddhesh/PPE_Detection** (~10K images, 6 PPE classes) |
| - **keremberke/construction-safety-object-detection** (398 images, violation classes) |
| |
| ## Classes ({len(UNIFIED_CLASSES)}) |
| {chr(10).join(f"- {i}: {name}" for i, name in enumerate(UNIFIED_CLASSES))} |
| |
| ## Usage |
| ```python |
| from ultralytics import YOLO |
| model = YOLO("hf://{repo_id}/best.pt") |
| results = model.predict("image.jpg") |
| ``` |
| |
| ## Training Details |
| - Base Model: YOLOv8s |
| - Epochs: {EPOCHS} |
| - Image Size: {IMG_SIZE}x{IMG_SIZE} |
| - Batch Size: {BATCH} |
| - Augmentations: Mosaic, HSV, scale, shear, flip |
| |
| ## Compliance Detection |
| The model detects both PPE presence AND absence: |
| - `no_helmet`, `no_mask`, `no_vest` = violation classes |
| - `helmet`, `mask`, `vest` = compliance classes |
| """ |
|
|
| api.upload_file( |
| path_or_fileobj=readme.encode(), |
| path_in_repo="README.md", |
| repo_id=repo_id, |
| repo_type="model", |
| ) |
|
|
| print(f" Model pushed to https://huggingface.co/{repo_id}") |
|
|
|
|
| def main(): |
| print("=" * 60) |
| print("IMPROVED PPE DETECTION TRAINING v2") |
| print("=" * 60) |
|
|
| ppe_dir = download_ppe_dataset() |
| keremberke_ds = load_keremberke_dataset() |
| keremberke_yolo_dir = Path("/app/keremberke_yolo") |
| convert_keremberke_to_yolo(keremberke_ds, keremberke_yolo_dir) |
| DATASET_DIR.mkdir(parents=True, exist_ok=True) |
| merge_datasets(ppe_dir, keremberke_yolo_dir, DATASET_DIR) |
| data_yaml = DATASET_DIR / "data.yaml" |
| results = train_model(data_yaml) |
|
|
| best_model = Path("/app/runs/ppe_improved/weights/best.pt") |
| if best_model.exists(): |
| push_to_hub(best_model) |
| else: |
| print(f" WARNING: Best model not found at {best_model}") |
| for pt_file in Path("/app/runs").rglob("best.pt"): |
| print(f" Found: {pt_file}") |
| push_to_hub(pt_file) |
| break |
|
|
| print("=" * 60) |
| print("DONE!") |
| print("=" * 60) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|