from abc import ABC, abstractmethod
from multiprocessing.pool import ThreadPool
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from point_e.models.download import load_checkpoint

from .npz_stream import NpzStreamer
from .pointnet2_cls_ssg import get_model


def get_torch_devices() -> List[Union[str, torch.device]]:
    if torch.cuda.is_available():
        return [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
    else:
        return ["cpu"]


class FeatureExtractor(ABC):
    @property
    @abstractmethod
    def supports_predictions(self) -> bool:
        pass

    @property
    @abstractmethod
    def feature_dim(self) -> int:
        pass

    @property
    @abstractmethod
    def num_classes(self) -> int:
        pass

    @abstractmethod
    def features_and_preds(self, streamer: NpzStreamer) -> Tuple[np.ndarray, np.ndarray]:
""" |
|
For a stream of point cloud batches, compute feature vectors and class |
|
predictions. |
|
|
|
:param point_clouds: a streamer for a sample batch. Typically, arr_0 |
|
will contain the XYZ coordinates. |
|
:return: a tuple (features, predictions) |
|
- features: a [B x feature_dim] array of feature vectors. |
|
- predictions: a [B x num_classes] array of probabilities. |
|
""" |
|
|
|
|
|
class PointNetClassifier(FeatureExtractor):
    def __init__(
        self,
        devices: List[Union[str, torch.device]],
        device_batch_size: int = 64,
        cache_dir: Optional[str] = None,
    ):
        state_dict = load_checkpoint("pointnet", device=torch.device("cpu"), cache_dir=cache_dir)[
            "model_state_dict"
        ]

        self.device_batch_size = device_batch_size
        self.devices = devices
        self.models = []
        # Load one copy of the classifier onto each device so that batches can
        # be split across devices and evaluated in parallel.
        for device in devices:
            model = get_model(num_class=40, normal_channel=False, width_mult=2)
            model.load_state_dict(state_dict)
            model.to(device)
            model.eval()
            self.models.append(model)

    @property
    def supports_predictions(self) -> bool:
        return True

    @property
    def feature_dim(self) -> int:
        return 256

    @property
    def num_classes(self) -> int:
        return 40

    def features_and_preds(self, streamer: NpzStreamer) -> Tuple[np.ndarray, np.ndarray]:
        batch_size = self.device_batch_size * len(self.devices)
        point_clouds = (x["arr_0"] for x in streamer.stream(batch_size, ["arr_0"]))

        output_features = []
        output_predictions = []

        with ThreadPool(len(self.devices)) as pool:
            for batch in point_clouds:
                batch = normalize_point_clouds(batch)
                # Split the global batch into one channels-first sub-batch per
                # device, each of shape [device_batch_size x 3 x num_points].
                batches = []
                for i, device in zip(range(0, len(batch), self.device_batch_size), self.devices):
                    batches.append(
                        torch.from_numpy(batch[i : i + self.device_batch_size])
                        .permute(0, 2, 1)
                        .to(dtype=torch.float32, device=device)
                    )

                def compute_features(i_batch):
                    i, batch = i_batch
                    with torch.no_grad():
                        return self.models[i](batch, features=True)

                # Evaluate the sub-batches in parallel, one thread per device.
                for logits, _, features in pool.imap(compute_features, enumerate(batches)):
                    output_features.append(features.cpu().numpy())
                    output_predictions.append(logits.exp().cpu().numpy())

        return np.concatenate(output_features, axis=0), np.concatenate(output_predictions, axis=0)


def normalize_point_clouds(pc: np.ndarray) -> np.ndarray:
    # Center each point cloud at its centroid and scale it so that its
    # farthest point lies on the unit sphere.
    centroids = np.mean(pc, axis=1, keepdims=True)
    pc = pc - centroids
    m = np.max(np.sqrt(np.sum(pc**2, axis=-1, keepdims=True)), axis=1, keepdims=True)
    pc = pc / m
    return pc
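

# Example usage -- a minimal sketch rather than part of the module itself. The
# NpzStreamer constructor argument below is an assumption: it is taken to
# accept a path to an .npz file whose arr_0 entry holds [B x num_points x 3]
# point clouds (matching the arr_0 key read in features_and_preds above).
#
#     classifier = PointNetClassifier(devices=get_torch_devices())
#     streamer = NpzStreamer("samples.npz")  # hypothetical path
#     features, preds = classifier.features_and_preds(streamer)
#     # features: [B x 256] feature vectors; preds: [B x 40] class probabilities.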