|
|
|
|
|
from dataclasses import make_dataclass |
|
from functools import lru_cache |
|
from typing import Any, Optional |
|
import torch |
|
|
|
|
|
@lru_cache(maxsize=None)
def decorate_predictor_output_class_with_confidences(BasePredictorOutput: type) -> type:
    """
    Create a new output class from an existing one by adding new attributes
    related to confidence estimation:
    - sigma_1 (tensor)
    - sigma_2 (tensor)
    - kappa_u (tensor)
    - kappa_v (tensor)
    - fine_segm_confidence (tensor)
    - coarse_segm_confidence (tensor)

    Details on confidence estimation parameters can be found in:
    N. Neverova, D. Novotny, A. Vedaldi "Correlated Uncertainty for Learning
    Dense Correspondences from Noisy Labels", p. 918--926, in Proc. NIPS 2019
    A. Sanakoyeu et al., Transferring Dense Pose to Proximal Animal Classes, CVPR 2020

    The new class inherits the provided `BasePredictorOutput` class;
    its name is composed of the name of the provided class and the
    "WithConfidences" suffix. All new attributes are optional and default
    to None. Results are memoized (``lru_cache``), so repeated calls with
    the same base class return the very same generated class.

    The generated class overrides ``__getitem__`` and ``to``: each delegates
    to the base class implementation for the base fields and applies the
    same slicing / device transfer to every non-None confidence tensor.

    Args:
        BasePredictorOutput (type): output type to which confidence data
            is to be added, assumed to be a dataclass. Its ``__getitem__``
            and ``to`` methods (when used) are expected to return instances
            whose ``__dict__`` holds keyword arguments accepted by the
            constructor.
    Returns:
        New dataclass derived from the provided one that has attributes
        for confidence estimation
    """
    # Single source of truth for the confidence attributes; the order here
    # defines the order of the generated dataclass fields.
    confidence_field_names = (
        "sigma_1",
        "sigma_2",
        "kappa_u",
        "kappa_v",
        "fine_segm_confidence",
        "coarse_segm_confidence",
    )

    PredictorOutput = make_dataclass(
        BasePredictorOutput.__name__ + "WithConfidences",
        fields=[(name, Optional[torch.Tensor], None) for name in confidence_field_names],
        bases=(BasePredictorOutput,),
    )

    def slice_if_not_none(data, item):
        """Index ``data`` with ``item``, keeping a leading dim for int items; pass None through."""
        if data is None:
            return None
        if isinstance(item, int):
            # Integer indexing drops the first dimension; restore it so the
            # result keeps the same rank as slice-based indexing.
            return data[item].unsqueeze(0)
        return data[item]

    def PredictorOutput_getitem(self, item):
        """Slice base fields via the base class and confidence tensors via `slice_if_not_none`."""
        # Bind super() to the generated class from the closure, not type(self):
        # super(type(self), self) recurses forever when this class is subclassed.
        base_predictor_output_sliced = super(PredictorOutput, self).__getitem__(item)
        confidences = {
            name: slice_if_not_none(getattr(self, name), item) for name in confidence_field_names
        }
        # Merge (rather than pass both as **kwargs) so confidence values win
        # even if the base result already carries confidence attributes.
        return type(self)(**{**base_predictor_output_sliced.__dict__, **confidences})

    PredictorOutput.__getitem__ = PredictorOutput_getitem

    def PredictorOutput_to(self, device: torch.device):
        """
        Transfers all tensors to the given device
        """
        # See PredictorOutput_getitem for why super() uses the closure class.
        base_predictor_output_to = super(PredictorOutput, self).to(device)

        def to_device_if_tensor(var: Any):
            if isinstance(var, torch.Tensor):
                return var.to(device)
            return var

        confidences = {
            name: to_device_if_tensor(getattr(self, name)) for name in confidence_field_names
        }
        return type(self)(**{**base_predictor_output_to.__dict__, **confidences})

    PredictorOutput.to = PredictorOutput_to

    return PredictorOutput
|
|