File size: 1,006 Bytes
287a0bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import importlib
import multiprocessing
from typing import Optional, Sequence, List
import numpy as np
from chromadb.api.types import URI, DataLoader, Image
from concurrent.futures import ThreadPoolExecutor


class ImageLoader(DataLoader[List[Optional[Image]]]):
    """Data loader that reads images from URIs into NumPy arrays.

    Images are opened with Pillow, which is imported lazily in ``__init__`` so
    the dependency is only required when a loader is actually instantiated.
    Loading is fanned out over a thread pool.
    """

    def __init__(self, max_workers: int = multiprocessing.cpu_count()) -> None:
        """Initialise the loader.

        Args:
            max_workers: Maximum number of threads used to load images
                concurrently. The default is the CPU count, evaluated once
                when the class is defined.

        Raises:
            ValueError: If the Pillow package cannot be imported.
        """
        # Keep the try body limited to the one statement that can raise
        # ImportError, and chain the cause so the real failure stays visible.
        try:
            self._PILImage = importlib.import_module("PIL.Image")
        except ImportError as exc:
            raise ValueError(
                "The PIL python package is not installed. Please install it with `pip install pillow`"
            ) from exc
        self._max_workers = max_workers

    def _load_image(self, uri: Optional[URI]) -> Optional[Image]:
        """Load a single image as a NumPy array.

        Args:
            uri: Location of the image to open, or ``None``.

        Returns:
            The decoded image as a NumPy array, or ``None`` when ``uri`` is
            ``None``.
        """
        if uri is None:
            return None
        # Pillow opens files lazily and may keep the handle open after
        # decoding (e.g. multi-frame formats). Converting inside the context
        # manager forces the decode and then releases the file deterministically.
        with self._PILImage.open(uri) as image:
            return np.array(image)

    def __call__(self, uris: Sequence[Optional[URI]]) -> List[Optional[Image]]:
        """Load every URI in ``uris`` concurrently.

        Args:
            uris: URIs to load; ``None`` entries are passed through as ``None``.

        Returns:
            A list of loaded images in the same order as ``uris``.
        """
        with ThreadPoolExecutor(max_workers=self._max_workers) as executor:
            # executor.map yields results in input order; list() drains it
            # before the pool shuts down, and re-raises any loading error here.
            return list(executor.map(self._load_image, uris))