Spaces:
Sleeping
Sleeping
File size: 2,759 Bytes
2fd6166 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
from contextlib import nullcontext
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers import ModelMixin
from torch import Tensor
from .pvcnn.pvcnn import PVCNN2
from .pvcnn.pvcnn_plus_plus import PVCNN2PlusPlus
from .simple.simple_model import SimplePointModel
class PointCloudModel(ModelMixin, ConfigMixin):
    """Point-cloud network wrapper that selects a backbone by ``model_type``.

    The backbone is built in ``__init__``. ``forward`` swaps axes 1 and 2
    before calling it and swaps them back on the prediction, so callers work
    with (B, N, C) tensors while the backbone receives (B, C, N).
    """

    @register_to_config
    def __init__(
        self,
        model_type: str = 'pvcnn',
        in_channels: int = 3,
        out_channels: int = 3,
        embed_dim: int = 64,
        dropout: float = 0.1,
        width_multiplier: int = 1,
        voxel_resolution_multiplier: int = 1,
    ):
        """Build the backbone selected by ``model_type``.

        Args:
            model_type: One of ``'pvcnn'``, ``'pvcnnplusplus'`` or ``'simple'``.
            in_channels: Channels per input point. ``in_channels - 3`` is passed
                to the backbone as ``extra_feature_channels`` (presumably the
                first 3 channels are xyz coordinates — confirm in the backbone).
            out_channels: Passed to the backbone as ``num_classes``.
            embed_dim: Passed to the backbone as ``embed_dim``.
            dropout: Only used by ``'pvcnn'``; ignored by the other backbones.
            width_multiplier: Only used by ``'pvcnn'``; ignored otherwise.
            voxel_resolution_multiplier: Only used by ``'pvcnn'``; ignored
                otherwise.

        Raises:
            NotImplementedError: If ``model_type`` is not a supported value.
        """
        super().__init__()
        self.model_type = model_type
        extra_feature_channels = in_channels - 3
        if self.model_type == 'pvcnn':
            # NOTE(review): presumably meant to keep the PVCNN backbone in
            # float32 even when an outer mixed-precision autocast is active.
            # Some PyTorch versions warn about / disable CUDA autocast for a
            # float32 target dtype — confirm this still has the intended effect.
            self.autocast_context = torch.autocast('cuda', dtype=torch.float32)
            self.model = PVCNN2(
                embed_dim=embed_dim,
                num_classes=out_channels,
                extra_feature_channels=extra_feature_channels,
                dropout=dropout,
                width_multiplier=width_multiplier,
                voxel_resolution_multiplier=voxel_resolution_multiplier,
            )
            self._init_output_layer(self.model.classifier[-1])
        elif self.model_type == 'pvcnnplusplus':
            # Same float32 autocast as the 'pvcnn' branch (see note there).
            self.autocast_context = torch.autocast('cuda', dtype=torch.float32)
            self.model = PVCNN2PlusPlus(
                embed_dim=embed_dim,
                num_classes=out_channels,
                extra_feature_channels=extra_feature_channels,
            )
            self._init_output_layer(self.model.output_projection[-1])
        elif self.model_type == 'simple':
            # No autocast override for the simple backbone.
            self.autocast_context = nullcontext()
            self.model = SimplePointModel(
                embed_dim=embed_dim,
                num_classes=out_channels,
                extra_feature_channels=extra_feature_channels,
            )
            self._init_output_layer(self.model.output_projection)
        else:
            raise NotImplementedError(
                f"Unsupported model_type {model_type!r}; expected one of "
                "'pvcnn', 'pvcnnplusplus', 'simple'."
            )

    @staticmethod
    def _init_output_layer(layer) -> None:
        """Re-initialise ``layer.bias`` and ``layer.weight`` from N(0, 1e-6).

        Near-zero parameters make this layer's initial outputs close to zero
        (presumably so the untrained model starts with near-zero predictions).
        The bias is drawn before the weight; keep that order so seeded runs
        reproduce the same initialisation.
        """
        torch.nn.init.normal_(layer.bias, mean=0.0, std=1e-6)
        torch.nn.init.normal_(layer.weight, mean=0.0, std=1e-6)

    def forward(self, inputs: Tensor, t: Tensor, ret_feats: bool = False) -> Tensor:
        """Run the backbone on a batch of point clouds.

        Receives input of shape (B, N, in_channels) and returns output of
        shape (B, N, out_channels).

        Args:
            inputs: Per-point features, shape (B, N, in_channels).
            t: Forwarded to the backbone unchanged (presumably the diffusion
                timestep — confirm against the backbone).
            ret_feats: If True, also return the backbone's auxiliary features.

        Returns:
            The prediction of shape (B, N, out_channels). When ``ret_feats`` is
            True, a ``(prediction, feats)`` tuple instead, where ``feats`` is
            returned exactly as the backbone produced it (not transposed).
        """
        with self.autocast_context:
            # Backbones take channels-first input: (B, N, C) -> (B, C, N).
            channels_first = inputs.transpose(1, 2)
            if not ret_feats:
                return self.model(channels_first, t, ret_feats=False).transpose(1, 2)
            pred, feats = self.model(channels_first, t, ret_feats=True)
            return pred.transpose(1, 2), feats