import copy
from typing import Iterator, Optional

from tqdm.auto import tqdm

from ..action import BaseAction
from ..export import BaseExporter
from ..model import ImageItem
from ..utils import task_ctx, get_task_names


class BaseDataSource:
    """Base class of all data sources; iterating one yields ``ImageItem`` objects.

    Subclasses supply the items by overriding :meth:`_iter`. Sources can be
    combined with ``|`` and ``+``, sliced with ``source[a:b:c]``, extended
    with actions through :meth:`attach`, and drained into an exporter through
    :meth:`export`.
    """

    def _iter(self) -> Iterator[ImageItem]:
        """Yield the items of this source. Must be overridden by subclasses."""
        raise NotImplementedError  # pragma: no cover

    def _iter_from(self) -> Iterator[ImageItem]:
        """Hook sitting between ``__iter__`` and ``_iter``.

        The base implementation forwards the items of ``_iter`` unchanged.
        """
        yield from self._iter()

    def __iter__(self) -> Iterator[ImageItem]:
        """Iterate over the items produced by ``_iter_from``."""
        yield from self._iter_from()

    def __or__(self, other):
        """Build a ``ParallelDataSource`` from ``self`` and ``other``.

        An operand that already is a ``ParallelDataSource`` contributes its
        ``sources`` instead of itself, so chained ``|`` stays flat.
        """
        from .compose import ParallelDataSource

        lhs = self.sources if isinstance(self, ParallelDataSource) else (self,)
        rhs = other.sources if isinstance(other, ParallelDataSource) else (other,)
        return ParallelDataSource(*lhs, *rhs)

    def __add__(self, other):
        """Build a ``ComposedDataSource`` from ``self`` and ``other``.

        An operand that already is a ``ComposedDataSource`` contributes its
        ``sources`` instead of itself, so chained ``+`` stays flat.
        """
        from .compose import ComposedDataSource

        lhs = self.sources if isinstance(self, ComposedDataSource) else (self,)
        rhs = other.sources if isinstance(other, ComposedDataSource) else (other,)
        return ComposedDataSource(*lhs, *rhs)

    def __getitem__(self, item):
        """Attach a ``SliceSelectAction`` built from the slice ``item``.

        :param item: A ``slice``; its start, stop and step are passed on.
        :return: The data source returned by :meth:`attach`.
        :raises TypeError: If ``item`` is not a ``slice``.
        """
        # Imported before the type check, matching the original evaluation order.
        from ..action import SliceSelectAction

        if not isinstance(item, slice):
            raise TypeError(f"Data source's getitem only accept slices, but {item!r} found.")
        return self.attach(SliceSelectAction(item.start, item.stop, item.step))

    def attach(self, *actions: BaseAction) -> 'AttachedDataSource':
        """Wrap this source in an ``AttachedDataSource`` carrying ``actions``."""
        return AttachedDataSource(self, *actions)

    def export(self, exporter: BaseExporter, name: Optional[str] = None):
        """Feed the items of this source into an exporter.

        The export runs on a deep copy of ``exporter`` that has been
        ``reset()``, inside the ``task_ctx(name)`` context.

        :param exporter: Exporter to copy and run.
        :param name: Value handed to ``task_ctx``.
        :return: Whatever the exporter's ``export_from`` returns.
        """
        fresh_exporter = copy.deepcopy(exporter)
        fresh_exporter.reset()
        with task_ctx(name):
            return fresh_exporter.export_from(iter(self))


class RootDataSource(BaseDataSource):
    """Data source whose iteration is wrapped in a ``tqdm`` progress bar."""

    def _iter(self) -> Iterator[ImageItem]:
        """Yield the items of this source. Must be overridden by subclasses."""
        raise NotImplementedError  # pragma: no cover

    def _iter_from(self) -> Iterator[ImageItem]:
        """Yield the items of ``_iter`` through a labelled progress bar.

        The label is the class name, followed by `` - `` and the dot-joined
        result of ``get_task_names()`` when that result is non-empty.
        """
        task_names = get_task_names()
        label = f'{self.__class__.__name__}'
        if task_names:
            label = f'{label} - {".".join(task_names)}'

        # Kept as an explicit loop (not ``yield from``) so that generator
        # close/throw handling stays exactly as it was.
        for image_item in tqdm(self._iter(), desc=label):
            yield image_item


class AttachedDataSource(BaseDataSource):
    """A data source with a sequence of actions applied on top of it."""

    def __init__(self, source: BaseDataSource, *actions: BaseAction):
        """Store the wrapped ``source`` and the ``actions`` to apply to it."""
        self.source = source
        self.actions = actions

    def _iter(self) -> Iterator[ImageItem]:
        """Yield the items of ``source`` after chaining every action over it.

        Each action is deep-copied and ``reset()`` before use, and the
        actions are applied in the order they were given.
        """
        stream = self.source
        for template in self.actions:
            stage = copy.deepcopy(template)
            stage.reset()
            stream = stage.iter_from(stream)

        yield from stream


class EmptySource(BaseDataSource):
    """Data source that yields no items."""

    def _iter(self) -> Iterator[ImageItem]:
        """Yield nothing."""
        yield from ()