|
|
|
|
|
from dataclasses import dataclass |
|
from typing import Union |
|
import torch |
|
|
|
|
|
@dataclass
class DensePoseEmbeddingPredictorOutput:
    """
    Predictor output that contains embedding and coarse segmentation data:
     * embedding: float tensor of size [N, D, H, W], contains estimated embeddings
     * coarse_segm: float tensor of size [N, K, H, W]
    Here D = MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE
         K = MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS

    Both tensors are indexed along their first dimension (the instance
    dimension N); indexing and device transfer always return a new
    `DensePoseEmbeddingPredictorOutput` rather than modifying this one.
    """

    # Field order matters: it defines the positional order of the generated
    # `__init__`, so it must stay (embedding, coarse_segm).
    embedding: torch.Tensor
    coarse_segm: torch.Tensor

    def __len__(self) -> int:
        """
        Number of instances (N) in the output

        Return:
            size of the first dimension of `coarse_segm`
        """
        return self.coarse_segm.size(0)

    def __getitem__(
        self, item: Union[int, slice, torch.BoolTensor]
    ) -> "DensePoseEmbeddingPredictorOutput":
        """
        Get outputs for the selected instance(s)

        Args:
            item (int or slice or tensor): selected items
        Return:
            a new output object holding the selected entries of both tensors;
            for an `int` index the instance dimension is kept with size 1
        """
        if isinstance(item, int):
            # Plain integer indexing drops the leading dimension; put it back
            # so the result is still a batch (of exactly one instance).
            return DensePoseEmbeddingPredictorOutput(
                coarse_segm=self.coarse_segm[item].unsqueeze(0),
                embedding=self.embedding[item].unsqueeze(0),
            )
        # Any other index (slice, boolean mask) is applied to both tensors as is.
        return DensePoseEmbeddingPredictorOutput(
            coarse_segm=self.coarse_segm[item],
            embedding=self.embedding[item],
        )

    def to(self, device: torch.device) -> "DensePoseEmbeddingPredictorOutput":
        """
        Transfers all tensors to the given device

        Args:
            device (torch.device): target device
        Return:
            a new output object whose tensors are the results of calling
            `.to(device)` on this object's tensors
        """
        coarse_segm = self.coarse_segm.to(device)
        embedding = self.embedding.to(device)
        return DensePoseEmbeddingPredictorOutput(coarse_segm=coarse_segm, embedding=embedding)
|
|