Spaces:
Running
on
L40S
Running
on
L40S
# Copyright (c) Facebook, Inc. and its affiliates. | |
from dataclasses import dataclass | |
from typing import Union | |
import torch | |
@dataclass
class DensePoseChartPredictorOutput:
    """
    Predictor output that contains segmentation and inner coordinates predictions for predefined
    body parts:
     * coarse segmentation, a tensor of shape [N, K, Hout, Wout]
     * fine segmentation, a tensor of shape [N, C, Hout, Wout]
     * U coordinates, a tensor of shape [N, C, Hout, Wout]
     * V coordinates, a tensor of shape [N, C, Hout, Wout]
    where
     - N is the number of instances
     - K is the number of coarse segmentation channels (
         2 = foreground / background,
         15 = one of 14 body parts / background)
     - C is the number of fine segmentation channels (
         24 fine body parts / background)
     - Hout and Wout are height and width of predictions

    The class is a dataclass so that it can be constructed with the keyword
    arguments ``coarse_segm``, ``fine_segm``, ``u`` and ``v``, which is how
    ``__getitem__`` and ``to`` build their results.
    """

    coarse_segm: torch.Tensor
    fine_segm: torch.Tensor
    u: torch.Tensor
    v: torch.Tensor

    def __len__(self) -> int:
        """
        Number of instances (N) in the output, i.e. the size of the first
        dimension of ``coarse_segm``.
        """
        return self.coarse_segm.size(0)

    def __getitem__(
        self, item: Union[int, slice, torch.BoolTensor]
    ) -> "DensePoseChartPredictorOutput":
        """
        Get outputs for the selected instance(s)

        Args:
            item (int or slice or tensor): selected items; anything other than
                a Python ``int`` is forwarded unchanged to tensor indexing on
                every field.
        Returns:
            DensePoseChartPredictorOutput: a new output holding the selected
                instances. For an ``int`` index the result still has a leading
                instance dimension (of size 1).
        """
        if isinstance(item, int):
            # Indexing a tensor with an int drops the leading dimension;
            # unsqueeze(0) restores it so every field keeps the [N, ...] layout.
            return DensePoseChartPredictorOutput(
                coarse_segm=self.coarse_segm[item].unsqueeze(0),
                fine_segm=self.fine_segm[item].unsqueeze(0),
                u=self.u[item].unsqueeze(0),
                v=self.v[item].unsqueeze(0),
            )
        return DensePoseChartPredictorOutput(
            coarse_segm=self.coarse_segm[item],
            fine_segm=self.fine_segm[item],
            u=self.u[item],
            v=self.v[item],
        )

    def to(self, device: torch.device) -> "DensePoseChartPredictorOutput":
        """
        Transfers all tensors to the given device

        Args:
            device (torch.device): target device, passed to ``Tensor.to``.
        Returns:
            DensePoseChartPredictorOutput: a new output whose fields are the
                results of ``Tensor.to(device)``; ``self`` is not modified.
        """
        coarse_segm = self.coarse_segm.to(device)
        fine_segm = self.fine_segm.to(device)
        u = self.u.to(device)
        v = self.v.to(device)
        return DensePoseChartPredictorOutput(coarse_segm=coarse_segm, fine_segm=fine_segm, u=u, v=v)