| | from pathlib import Path |
| |
|
# Class label passed to enforce_single_class_dataset() by the script entry point.
class_name = 'strawberry'
| |
|
def rewrite_single_class_data_yaml(dataset_dir, class_name='strawberry'):
    """Overwrite ``<dataset_dir>/data.yaml`` with a one-class dataset config.

    The file is only rewritten when it already exists; when it is missing a
    warning is printed and nothing is created.

    Args:
        dataset_dir: Dataset root (``str`` or path-like). The written config
            points at ``train/images``, ``valid/images`` and ``test/images``
            below this directory.
        class_name: Name written as the single entry of ``names``.

    Returns:
        None.
    """
    dataset_dir = Path(dataset_dir)
    data_yaml_path = dataset_dir / 'data.yaml'
    if not data_yaml_path.exists():
        print('⚠️ data.yaml not found, skipping rewrite.')
        return

    train_path = dataset_dir / 'train' / 'images'
    val_path = dataset_dir / 'valid' / 'images'
    test_path = dataset_dir / 'test' / 'images'

    # The test split is optional: its value is left empty when the folder
    # does not exist.
    test_entry = test_path if test_path.exists() else ''

    # Inside a single-quoted YAML scalar a literal quote must be doubled,
    # otherwise a name such as "d'anjou" would produce an unparsable file.
    quoted_name = str(class_name).replace("'", "''")

    content_lines = [
        '# Strawberry-only dataset',
        f'train: {train_path}',
        f'val: {val_path}',
        f'test: {test_entry}',
        '',
        'nc: 1',
        f"names: ['{quoted_name}']",
    ]

    # Explicit UTF-8 so non-ASCII class names are written the same way on
    # every platform instead of depending on the locale's default encoding.
    data_yaml_path.write_text('\n'.join(content_lines) + '\n', encoding='utf-8')
    print(f"✅ data.yaml updated for single-class training ({class_name}).")
| |
|
| |
|
def _keep_target_lines(label_text, target_class):
    """Return the annotation lines of *label_text* whose class id is *target_class*.

    The first whitespace-separated token of each line is parsed as an integer
    class id; blank lines and lines whose first token is not an integer are
    dropped. Kept lines are stripped, and their class id is rewritten to ``0``
    when it is not already ``0``.
    """
    kept_lines = []
    for raw_line in label_text.splitlines():
        parts = raw_line.split()
        if not parts:
            continue
        try:
            class_id = int(parts[0])
        except ValueError:
            # Not an annotation line (non-integer first token): discard it.
            continue
        if class_id != target_class:
            continue
        if class_id == 0:
            kept_lines.append(raw_line.strip())
        else:
            # data.yaml is rewritten with nc: 1, so the only valid class
            # index afterwards is 0; remap the surviving class onto it.
            kept_lines.append(' '.join(['0', *parts[1:]]))
    return kept_lines


def enforce_single_class_dataset(dataset_dir, target_class=0, class_name='strawberry'):
    """Reduce the label files of every split to a single class, in place.

    For each of the ``train``, ``valid`` and ``test`` splits that has a
    ``labels`` directory, every ``*.txt`` label file is rewritten to contain
    only the annotations of ``target_class``. A label file left without any
    annotation is deleted, together with the first image found in the split's
    ``images`` directory that has the same stem and a ``.jpg``, ``.jpeg`` or
    ``.png`` extension. Afterwards ``data.yaml`` is rewritten through
    ``rewrite_single_class_data_yaml`` and a summary is printed.

    Args:
        dataset_dir: Dataset root (``str`` or path-like) containing the split
            folders.
        target_class: Integer class id to keep. Kept annotations are written
            back with class id ``0``.
        class_name: Class name forwarded to ``rewrite_single_class_data_yaml``.

    Returns:
        dict with the keys ``'split_kept'`` (split name -> number of kept
        annotation lines, only for splits that have a labels directory),
        ``'labels_removed'`` and ``'images_removed'`` (counts of deleted
        files).
    """
    dataset_dir = Path(dataset_dir)
    stats = {'split_kept': {}, 'labels_removed': 0, 'images_removed': 0}
    allowed_ext = ('.jpg', '.jpeg', '.png')

    for split in ('train', 'valid', 'test'):
        labels_dir = dataset_dir / split / 'labels'
        images_dir = dataset_dir / split / 'images'
        if not labels_dir.exists():
            continue

        kept = 0
        # Snapshot the directory listing before touching it: files are
        # rewritten or deleted inside the loop.
        for label_path in sorted(labels_dir.glob('*.txt')):
            kept_lines = _keep_target_lines(label_path.read_text(), target_class)

            if kept_lines:
                label_path.write_text('\n'.join(kept_lines) + '\n')
                kept += len(kept_lines)
                continue

            # Nothing of the target class is left: drop the label file and
            # its matching image (first extension that exists).
            label_path.unlink()
            stats['labels_removed'] += 1
            for ext in allowed_ext:
                candidate = images_dir / f"{label_path.stem}{ext}"
                if candidate.exists():
                    candidate.unlink()
                    stats['images_removed'] += 1
                    break

        stats['split_kept'][split] = kept

    rewrite_single_class_data_yaml(dataset_dir, class_name)

    print('\n🍓 Strawberry-only filtering summary:')
    for split, count in stats['split_kept'].items():
        print(f" {split}: {count} annotations kept")
    print(f" Label files removed: {stats['labels_removed']}")
    print(f" Images removed (non-strawberry or empty labels): {stats['images_removed']}")

    return stats
| |
|
| |
|
if __name__ == "__main__":
    # Placeholder: set this to the dataset root before running the script.
    dataset_path = "path/to/your/dataset"

    # Guard against an unset or non-existent path before modifying any files.
    if not dataset_path or not Path(dataset_path).exists():
        print("Please set a valid dataset_path variable.")
    else:
        strawberry_stats = enforce_single_class_dataset(
            dataset_path, target_class=0, class_name=class_name
        )