| |
|
|
| from dataclasses import make_dataclass |
| from functools import lru_cache |
| from typing import Any, Optional |
| import torch |
|
|
|
|
@lru_cache(maxsize=None)
def decorate_cse_predictor_output_class_with_confidences(BasePredictorOutput: type) -> type:
    """
    Create a new output class from an existing one by adding new attributes
    related to confidence estimation:
    - coarse_segm_confidence (tensor)

    Details on confidence estimation parameters can be found in:
    N. Neverova, D. Novotny, A. Vedaldi "Correlated Uncertainty for Learning
        Dense Correspondences from Noisy Labels", p. 918--926, in Proc. NIPS 2019
    A. Sanakoyeu et al., Transferring Dense Pose to Proximal Animal Classes, CVPR 2020

    The new class inherits the provided `BasePredictorOutput` class,
    its name is composed of the name of the provided class and
    "WithConfidences" suffix.

    The generated class overrides `__getitem__` and `to`. Both delegate to
    the base class implementation first and then rebuild an instance from the
    `__dict__` of the base result plus the processed confidence tensor, so
    the base class is expected to provide `__getitem__(item)` and
    `to(device)` that return objects whose `__dict__` holds the base fields.

    The result is memoized per base class (`lru_cache`), so decorating the
    same class twice returns the very same derived class object.

    Args:
        BasePredictorOutput (type): output type to which confidence data
            is to be added, assumed to be a dataclass
    Return:
        New dataclass derived from the provided one that has attributes
        for confidence estimation
    """

    PredictorOutput = make_dataclass(
        BasePredictorOutput.__name__ + "WithConfidences",
        fields=[
            ("coarse_segm_confidence", Optional[torch.Tensor], None),
        ],
        bases=(BasePredictorOutput,),
    )

    # add possibility to index PredictorOutput

    def slice_if_not_none(data, item):
        """Index `data` with `item`; `None` is passed through unchanged."""
        if data is None:
            return None
        if isinstance(item, int):
            # an integer index drops the leading dimension; restore it so the
            # result has the same rank as a slice-indexed result
            return data[item].unsqueeze(0)
        return data[item]

    def PredictorOutput_getitem(self, item):
        """
        Index the base fields via the base class and the confidence tensor
        via `slice_if_not_none`, returning an instance of `type(self)`.
        """
        # super() must be anchored at the class created above, NOT at
        # type(self): with type(self) a subclass instance would resolve back
        # to this very method and recurse until RecursionError.
        base_predictor_output_sliced = super(PredictorOutput, self).__getitem__(item)
        # construct through type(self) so instances of subclasses keep their type
        return type(self)(
            **base_predictor_output_sliced.__dict__,
            coarse_segm_confidence=slice_if_not_none(self.coarse_segm_confidence, item),
        )

    PredictorOutput.__getitem__ = PredictorOutput_getitem

    def PredictorOutput_to(self, device: torch.device):
        """
        Transfers all tensors to the given device
        """
        # anchored at the closure class for the same reason as in __getitem__
        base_predictor_output_to = super(PredictorOutput, self).to(device)

        def to_device_if_tensor(var: Any):
            # non-tensor values (e.g. None) are returned as they are
            if isinstance(var, torch.Tensor):
                return var.to(device)
            return var

        return type(self)(
            **base_predictor_output_to.__dict__,
            coarse_segm_confidence=to_device_if_tensor(self.coarse_segm_confidence),
        )

    PredictorOutput.to = PredictorOutput_to
    return PredictorOutput
|
|