from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Generic, TypeVar, List, Iterator, Optional
import logging
from multiprocessing import Pool, cpu_count

import torch  # Hard dependency: used for CUDA detection and cache cleanup

T = TypeVar('T')  # Input type
R = TypeVar('R')  # Result type


@dataclass
class BatchOptions:
    """Common options for batch processing tasks."""
    batch_size: int = 16
    num_workers: Optional[int] = None
    device: str = "cuda"  # "cuda" or "cpu"
    debug: bool = False
    skip_existing: bool = True
    recursive: bool = False
    supported_extensions: set[str] = field(
        default_factory=lambda: {'.png', '.jpg', '.jpeg', '.webp', '.jxl'}
    )
    gen_threshold: float = 0.35   # General tag confidence threshold
    char_threshold: float = 0.75  # Character tag confidence threshold


class BatchProcessor(Generic[T, R], ABC):
    """Base class for batch processing operations."""

    def __init__(self, opts: BatchOptions):
        self.opts = opts
        self._setup_logging()
        self.device = self._setup_device()

    def _setup_logging(self) -> None:
        level = logging.DEBUG if self.opts.debug else logging.INFO
        logging.basicConfig(level=level)
        self.logger = logging.getLogger(self.__class__.__name__)

    def _setup_device(self) -> str:
        if self.opts.device == "cuda" and not self._cuda_available():
            self.logger.warning("CUDA requested but not available. Falling back to CPU.")
            return "cpu"
        return self.opts.device

    def _cuda_available(self) -> bool:
        """Report whether a CUDA device can be used."""
        # torch is imported at module level, so an ImportError can never be
        # raised here; a simple capability check suffices.
        return torch.cuda.is_available()

    @abstractmethod
    def process_item(self, item: T) -> R:
        """Process a single item."""

    @abstractmethod
    def should_process_item(self, item: T) -> bool:
        """Determine whether an item should be processed."""

    def create_batch_iterator(self, items: Iterator[T]) -> Iterator[List[T]]:
        """Group accepted items into batches of at most opts.batch_size."""
        batch: List[T] = []
        for item in items:
            if self.should_process_item(item):
                batch.append(item)
                if len(batch) >= self.opts.batch_size:
                    yield batch
                    batch = []
        if batch:  # Flush the final, possibly partial batch
            yield batch

    def process_batch(self, batch: List[T]) -> List[R]:
        """Process a batch of items, logging and skipping items that raise."""
        results: List[R] = []
        for item in batch:
            try:
                results.append(self.process_item(item))
            except Exception:
                self.logger.exception("Error processing item %r", item)
        return results

    def process_all(self, items: Iterator[T], parallel: bool = True) -> Iterator[R]:
        """Process all items, optionally in parallel across worker processes."""
        batches = list(self.create_batch_iterator(items))
        if not batches:
            self.logger.warning("No items to process")
            return
        if parallel and self.opts.num_workers != 1:
            num_workers = self.opts.num_workers or max(1, cpu_count() // 2)
            self.logger.info("Processing %d batches using %d workers", len(batches), num_workers)
            # Worker processes receive pickled copies of self. Note that CUDA
            # contexts do not survive fork, so GPU inference is best run with
            # parallel=False or a 'spawn' start method.
            with Pool(num_workers) as pool:
                for batch_results in pool.imap(self.process_batch, batches):
                    yield from batch_results
        else:
            self.logger.info("Processing %d batches sequentially", len(batches))
            for batch in batches:
                yield from self.process_batch(batch)

    def __del__(self):
        """Release cached CUDA memory when the processor is garbage-collected."""
        if hasattr(self, 'device') and self.device == 'cuda':
            torch.cuda.empty_cache()
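

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). ImageTagger, iter_images, and the
# placeholder "tagging" below are hypothetical names, not part of the API
# above. A real subclass would load a model in __init__ and run inference on
# self.device in process_item, applying gen_threshold / char_threshold.
# ---------------------------------------------------------------------------
class ImageTagger(BatchProcessor[Path, tuple[Path, str]]):
    """Example subclass: 'tags' image files found under a directory."""

    def should_process_item(self, item: Path) -> bool:
        # Filter candidates by extension using the shared options.
        return item.suffix.lower() in self.opts.supported_extensions

    def process_item(self, item: Path) -> tuple[Path, str]:
        # Placeholder for model inference on self.device.
        return (item, f"tagged:{item.name}")


def iter_images(root: Path, recursive: bool) -> Iterator[Path]:
    """Yield candidate files under root, optionally recursing."""
    pattern = "**/*" if recursive else "*"
    return (p for p in root.glob(pattern) if p.is_file())


if __name__ == "__main__":
    opts = BatchOptions(batch_size=8, device="cpu", debug=True, recursive=True)
    tagger = ImageTagger(opts)
    # parallel=False keeps everything in-process; see the fork/CUDA note in
    # process_all before enabling worker processes for GPU workloads.
    for path, tag in tagger.process_all(iter_images(Path("."), opts.recursive), parallel=False):
        print(path, tag)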