Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
| # pyre-unsafe | |
| from dataclasses import dataclass | |
| from typing import Union | |
| import torch | |
@dataclass
class DensePoseEmbeddingPredictorOutput:
    """
    Predictor output that contains embedding and coarse segmentation data:
     * embedding: float tensor of size [N, D, H, W], contains estimated embeddings
     * coarse_segm: float tensor of size [N, K, H, W]
    Here N is the number of instances,
     D = MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE,
     K = MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS
    """

    # NOTE: the @dataclass decorator is required: the methods below construct
    # new instances via keyword arguments (coarse_segm=..., embedding=...),
    # which only works with the generated __init__.
    embedding: torch.Tensor
    coarse_segm: torch.Tensor

    def __len__(self) -> int:
        """
        Number of instances (N) in the output, i.e. the size of the first
        dimension of `coarse_segm`.
        """
        return self.coarse_segm.size(0)

    def __getitem__(
        self, item: Union[int, slice, torch.BoolTensor]
    ) -> "DensePoseEmbeddingPredictorOutput":
        """
        Get outputs for the selected instance(s)

        Args:
            item (int or slice or tensor): selected items. An integer index
                selects a single instance; the leading instance dimension is
                kept (size 1) so the result has the same rank as the input.
                Slices and boolean masks are forwarded to tensor indexing.
        Returns:
            A new DensePoseEmbeddingPredictorOutput holding the selection.
        """
        if isinstance(item, int):
            # Plain integer indexing drops dim 0; restore it so consumers can
            # always rely on an [N, ...] layout.
            return DensePoseEmbeddingPredictorOutput(
                coarse_segm=self.coarse_segm[item].unsqueeze(0),
                embedding=self.embedding[item].unsqueeze(0),
            )
        return DensePoseEmbeddingPredictorOutput(
            coarse_segm=self.coarse_segm[item], embedding=self.embedding[item]
        )

    def to(self, device: torch.device) -> "DensePoseEmbeddingPredictorOutput":
        """
        Transfers all tensors to the given device.

        Args:
            device: target device, passed unchanged to `torch.Tensor.to`.
        Returns:
            A new DensePoseEmbeddingPredictorOutput with both tensors moved.
        """
        coarse_segm = self.coarse_segm.to(device)
        embedding = self.embedding.to(device)
        return DensePoseEmbeddingPredictorOutput(coarse_segm=coarse_segm, embedding=embedding)