# Copyright (c) Facebook, Inc. and its affiliates.

import random
from itertools import zip_longest
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple

import torch
from torch import nn

SampledData = Any
ModelOutput = Any


def _grouper(iterable: Iterable[Any], n: int, fillvalue=None) -> Iterator[Tuple[Any, ...]]:
    """
    Group elements of an iterable by chunks of size `n`, e.g.
    grouper(range(9), 4) ->
        (0, 1, 2, 3), (4, 5, 6, 7), (8, None, None, None)

    The last chunk is padded with `fillvalue` when the number of elements is
    not a multiple of `n`; an empty iterable produces no chunks at all.

    Args:
        iterable (Iterable[Any]): elements to group
        n (int): chunk size, must be positive
        fillvalue: value used to pad the last, incomplete chunk
    Returns:
        Iterator over tuples of exactly `n` elements
    Raises:
        ValueError: if `n` is not positive (a non-positive chunk size would
            otherwise produce empty tuples forever)
    """
    if n < 1:
        raise ValueError(f"Chunk size must be positive, got {n}")
    # The same iterator repeated `n` times makes zip_longest pull `n`
    # consecutive elements per output tuple (itertools "grouper" recipe).
    return zip_longest(*([iter(iterable)] * n), fillvalue=fillvalue)


class ScoreBasedFilter:
    """
    Filters entries in model output based on their scores.

    For every model output that has a "scores" field on its "instances",
    keeps only the instances whose score is >= the specified minimum.
    Outputs without scores are left untouched. The model output list is
    modified in place and also returned.
    """

    def __init__(self, min_score: float = 0.8):
        self.min_score = min_score

    def __call__(self, model_output: ModelOutput) -> ModelOutput:
        for model_output_i in model_output:
            instances = model_output_i["instances"]
            if not instances.has("scores"):
                continue
            instances_filtered = instances[instances.scores >= self.min_score]
            model_output_i["instances"] = instances_filtered
        return model_output


class InferenceBasedLoader:
    """
    Data loader based on results inferred by a model. Consists of:
     - a data loader that provides batches of images
     - a model that is used to infer the results
     - a data sampler that converts inferred results to annotations
    """

    def __init__(
        self,
        model: nn.Module,
        data_loader: Iterable[List[Dict[str, Any]]],
        data_sampler: Optional[Callable[[ModelOutput], List[SampledData]]] = None,
        data_filter: Optional[Callable[[ModelOutput], ModelOutput]] = None,
        shuffle: bool = True,
        batch_size: int = 4,
        inference_batch_size: int = 4,
        drop_last: bool = False,
        category_to_class_mapping: Optional[dict] = None,
    ):
        """
        Constructor

        Args:
          model (torch.nn.Module): model used to produce data; it is switched to
            eval mode and must expose a `device` attribute
          data_loader (Iterable[List[Dict[str, Any]]]): iterable that provides
            dictionaries with "images" and "categories" fields to perform inference on
          data_sampler (Callable: ModelOutput -> SampledData): functor
              that produces annotation data from inference results;
              (optional, default: None)
          data_filter (Callable: ModelOutput -> ModelOutput): filter
              that selects model outputs for further processing
              (optional, default: None)
          shuffle (bool): if True, the input images get shuffled
          batch_size (int): batch size for the produced annotation data,
              must be positive
          inference_batch_size (int): batch size for input images,
              must be positive
          drop_last (bool): if True, drop the last batch if it is undersized
          category_to_class_mapping (dict): category to class mapping;
              categories missing from it are assigned class 0
        Raises:
          ValueError: if `batch_size` or `inference_batch_size` is not positive
        """
        # Non-positive sizes would make the producer loop forever, so reject
        # them before touching the model.
        if batch_size < 1:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if inference_batch_size < 1:
            raise ValueError(f"inference_batch_size must be positive, got {inference_batch_size}")
        self.model = model
        self.model.eval()
        self.data_loader = data_loader
        self.data_sampler = data_sampler
        self.data_filter = data_filter
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.inference_batch_size = inference_batch_size
        self.drop_last = drop_last
        if category_to_class_mapping is not None:
            self.category_to_class_mapping = category_to_class_mapping
        else:
            self.category_to_class_mapping = {}

    def __iter__(self) -> Iterator[List[SampledData]]:
        """
        Iterate over batches of sampled data.

        Each batch from the underlying data loader is flattened into a list of
        {"image": ..., "category": ...} dicts (optionally shuffled), which is then
        run through the model, filter and sampler.
        """
        for batch in self.data_loader:
            # batch : List[Dict[str: Tensor[N, C, H, W], str: Optional[str]]]
            # images_batch : Tensor[N, C, H, W]
            # image : Tensor[C, H, W]
            images_and_categories = [
                {"image": image, "category": category}
                for element in batch
                for image, category in zip(element["images"], element["categories"])
            ]
            if not images_and_categories:
                continue
            if self.shuffle:
                random.shuffle(images_and_categories)
            yield from self._produce_data(images_and_categories)

    def _produce_data(
        self, images_and_categories: List[Dict[str, Any]]
    ) -> Iterator[List[SampledData]]:
        """
        Produce batches of data from images

        Args:
          images_and_categories (List[Dict[str, Any]]):
            list of {"image": Tensor[C, H, W], "category": Optional[str]} dicts
            to process
        Returns:
          Iterator over batches of data sampled from model outputs; batches have
          exactly `self.batch_size` entries, except possibly the last one
          (emitted only if `self.drop_last` is False)
        """
        data_batches: List[SampledData] = []
        category_to_class_mapping = self.category_to_class_mapping
        device = self.model.device
        for chunk in _grouper(images_and_categories, self.inference_batch_size):
            # _grouper pads the last chunk with None; drop the padding
            model_input = [
                {
                    "image": entry["image"].to(device),
                    "category": entry["category"],
                }
                for entry in chunk
                if entry is not None
            ]
            if not model_input:
                continue
            with torch.no_grad():
                model_output = self.model(model_input)
            for model_output_i, input_i in zip(model_output, model_input):
                assert len(input_i["image"].shape) == 3
                model_output_i["image"] = input_i["image"]
                instances = model_output_i["instances"]
                instance_class = category_to_class_mapping.get(input_i["category"], 0)
                # Explicit dtype: torch.tensor([]) would be float32 for zero
                # instances but int64 otherwise. Same device as the image so
                # that boolean masks built from other (device-resident) fields,
                # e.g. in ScoreBasedFilter, can index this field as well.
                instances.dataset_classes = torch.full(
                    (len(instances),),
                    instance_class,
                    dtype=torch.long,
                    device=input_i["image"].device,
                )
            model_output_filtered = (
                model_output if self.data_filter is None else self.data_filter(model_output)
            )
            data = (
                model_output_filtered
                if self.data_sampler is None
                else self.data_sampler(model_output_filtered)
            )
            for data_i in data:
                # keep only entries that still have instances after filtering/sampling
                if len(data_i["instances"]):
                    data_batches.append(data_i)
                    # entries are appended one at a time, so the batch is full
                    # exactly when its length reaches batch_size
                    if len(data_batches) == self.batch_size:
                        yield data_batches
                        data_batches = []
        if not self.drop_last and data_batches:
            yield data_batches