# Copyright (c) Facebook, Inc. and its affiliates.

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

from detectron2.structures import Instances

ModelOutput = Dict[str, Any]
SampledData = Dict[str, Any]


@dataclass
class _Sampler:
    """
    Sampler registry entry that contains:
     - src (str): source field to sample from (deleted after sampling)
     - dst (Optional[str]): destination field to sample to, if not None
     - func (Optional[Callable: Any -> Any]): function that performs sampling;
       if None, a reference copy is performed
    """

    src: str
    dst: Optional[str]
    func: Optional[Callable[[Any], Any]]
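

# Illustrative sketch, not part of the original module: as `__call__` below shows,
# a registered sampling `func` receives the whole `Instances` object and returns
# the value to store in the destination field. The function name, the
# "pred_masks" field, and the 0.5 threshold are assumptions made up for this
# example; real samplers are registered by callers for their own fields.
def _example_mask_sampler(instances: Instances) -> Any:
    # Binarize soft mask predictions into hard ground-truth masks.
    return instances.get("pred_masks") > 0.5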


class PredictionToGroundTruthSampler:
    """
    Sampler implementation that converts predictions to GT using registered
    samplers for different fields of `Instances`.
    """

    def __init__(self, dataset_name: str = ""):
        self.dataset_name = dataset_name
        self._samplers = {}
        # by default, reference-copy predicted boxes and classes into the GT fields
        self.register_sampler("pred_boxes", "gt_boxes", None)
        self.register_sampler("pred_classes", "gt_classes", None)
        # delete scores
        self.register_sampler("scores")

    def __call__(self, model_output: List[ModelOutput]) -> List[SampledData]:
        """
        Transform model output into ground truth data through sampling

        Args:
            model_output (List[ModelOutput]): model output
        Returns:
            List[SampledData]: sampled data
        """
        for model_output_i in model_output:
            instances: Instances = model_output_i["instances"]
            # transform data in each field
            for _, sampler in self._samplers.items():
                if not instances.has(sampler.src) or sampler.dst is None:
                    continue
                if sampler.func is None:
                    instances.set(sampler.dst, instances.get(sampler.src))
                else:
                    instances.set(sampler.dst, sampler.func(instances))
            # delete model output data that was transformed
            for _, sampler in self._samplers.items():
                if sampler.src != sampler.dst and instances.has(sampler.src):
                    instances.remove(sampler.src)
            model_output_i["dataset"] = self.dataset_name
        return model_output

    def register_sampler(
        self,
        prediction_attr: str,
        gt_attr: Optional[str] = None,
        func: Optional[Callable[[Any], Any]] = None,
    ):
        """
        Register sampler for a field

        Args:
            prediction_attr (str): field to replace with a sampled value
            gt_attr (Optional[str]): field to store the sampled value to, if not None
            func (Optional[Callable: Any -> Any]): sampler function
        """
        self._samplers[(prediction_attr, gt_attr)] = _Sampler(
            src=prediction_attr, dst=gt_attr, func=func
        )

    def remove_sampler(
        self,
        prediction_attr: str,
        gt_attr: Optional[str] = None,
    ):
        """
        Remove sampler for a field

        Args:
            prediction_attr (str): field to replace with a sampled value
            gt_attr (Optional[str]): field to store the sampled value to, if not None
        """
        assert (prediction_attr, gt_attr) in self._samplers
        del self._samplers[(prediction_attr, gt_attr)]
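

# --- Usage sketch (illustrative, not part of the original module) ------------
# Demonstrates the default behavior of PredictionToGroundTruthSampler: predicted
# boxes and classes are copied into the corresponding GT fields, the prediction
# fields and scores are removed, and the dataset name is attached. The image
# size, field values, and the "demo_dataset" name are made up for this example.
if __name__ == "__main__":
    import torch

    from detectron2.structures import Boxes

    # Fake "model output": one image with a single detected instance
    instances = Instances((480, 640))
    instances.pred_boxes = Boxes(torch.tensor([[0.0, 0.0, 10.0, 10.0]]))
    instances.pred_classes = torch.tensor([0])
    instances.scores = torch.tensor([0.9])

    sampler = PredictionToGroundTruthSampler(dataset_name="demo_dataset")
    # a custom sampler for another field could be registered as well, e.g.
    # sampler.register_sampler("pred_masks", "gt_masks", some_sampling_func)
    sampled = sampler([{"instances": instances}])

    assert sampled[0]["instances"].has("gt_boxes")
    assert not sampled[0]["instances"].has("pred_boxes")
    assert not sampled[0]["instances"].has("scores")
    print(sampled[0]["dataset"])  # -> "demo_dataset"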