Roopansh's picture
Initial Commit
73c83cf
raw
history blame
2.32 kB
# Copyright (c) Facebook, Inc. and its affiliates.
import operator
from dataclasses import dataclass
from typing import Optional, Union

import torch
@dataclass
class DensePoseChartPredictorOutput:
    """
    Predictor output that contains segmentation and inner coordinates predictions for predefined
    body parts:
     * coarse segmentation, a tensor of shape [N, K, Hout, Wout]
     * fine segmentation, a tensor of shape [N, C, Hout, Wout]
     * U coordinates, a tensor of shape [N, C, Hout, Wout]
     * V coordinates, a tensor of shape [N, C, Hout, Wout]
    where
     - N is the number of instances
     - K is the number of coarse segmentation channels (
         2 = foreground / background,
         15 = one of 14 body parts / background)
     - C is the number of fine segmentation channels (
         24 fine body parts / background)
     - Hout and Wout are height and width of predictions

    Indexing (``output[item]``) returns a new ``DensePoseChartPredictorOutput``
    whose tensors keep the leading instance dimension, so selecting a single
    instance yields an output of length 1.
    """

    coarse_segm: torch.Tensor
    fine_segm: torch.Tensor
    u: torch.Tensor
    v: torch.Tensor

    def __len__(self):
        """
        Number of instances (N) in the output
        """
        # Only coarse_segm is inspected; the other tensors are assumed to share dim 0.
        return self.coarse_segm.size(0)

    @staticmethod
    def _as_scalar_index(item) -> Optional[int]:
        """
        Return ``item`` as a Python int if it selects exactly one instance, else ``None``.

        Python ints, NumPy integers and 0-dim signed-integer tensors are treated as
        single-instance selections. Bools, 0-dim bool/uint8 tensors, slices, sequences
        and multi-element tensors yield ``None`` and are forwarded to tensor indexing.
        """
        if isinstance(item, bool):
            # bool subclasses int, but tensor indexing gives True/False mask-like
            # semantics rather than selecting a position.
            return None
        if isinstance(item, torch.Tensor):
            # Checked before operator.index: a 1-element 1-D tensor would otherwise be
            # converted too, and uint8/bool tensors are masks in torch indexing.
            is_int_scalar = item.dim() == 0 and item.dtype in (
                torch.int8,
                torch.int16,
                torch.int32,
                torch.int64,
            )
            return int(item) if is_int_scalar else None
        try:
            return operator.index(item)
        except TypeError:
            return None

    def __getitem__(
        self, item: Union[int, slice, torch.Tensor]
    ) -> "DensePoseChartPredictorOutput":
        """
        Get outputs for the selected instance(s)

        Args:
            item (int or slice or tensor): selected items. Any integer-like scalar
                (Python int, NumPy integer, 0-dim integer tensor) selects one instance;
                anything else is passed through to tensor indexing.
        Returns:
            A new ``DensePoseChartPredictorOutput`` that keeps the instance dimension.
        """
        index = self._as_scalar_index(item)
        if index is not None:
            # Integer indexing drops dim 0; restore it so the result is still batched.
            return DensePoseChartPredictorOutput(
                coarse_segm=self.coarse_segm[index].unsqueeze(0),
                fine_segm=self.fine_segm[index].unsqueeze(0),
                u=self.u[index].unsqueeze(0),
                v=self.v[index].unsqueeze(0),
            )
        return DensePoseChartPredictorOutput(
            coarse_segm=self.coarse_segm[item],
            fine_segm=self.fine_segm[item],
            u=self.u[item],
            v=self.v[item],
        )

    def to(self, device: Union[torch.device, str]):
        """
        Transfers all tensors to the given device

        Args:
            device: target device, forwarded unchanged to ``torch.Tensor.to``
        Returns:
            A new ``DensePoseChartPredictorOutput``; ``self`` is not modified.
        """
        return DensePoseChartPredictorOutput(
            coarse_segm=self.coarse_segm.to(device),
            fine_segm=self.fine_segm.to(device),
            u=self.u.to(device),
            v=self.v.to(device),
        )