File size: 4,326 Bytes
7f51798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from abc import ABC, abstractmethod
from multiprocessing.pool import ThreadPool
from typing import List, Optional, Tuple, Union
from tqdm import tqdm

from pdb import set_trace as st
import numpy as np
import torch

from point_e.models.download import load_checkpoint

from npz_stream import NpzStreamer
from pointnet2_cls_ssg import get_model


def get_torch_devices() -> List[Union[str, torch.device]]:
    """
    List the devices available for inference.

    :return: one ``torch.device`` per visible CUDA GPU (``cuda:0``,
             ``cuda:1``, ...) when CUDA is available, otherwise the
             single-element list ``["cpu"]``.
    """
    if not torch.cuda.is_available():
        return ["cpu"]
    gpu_count = torch.cuda.device_count()
    return [torch.device(f"cuda:{idx}") for idx in range(gpu_count)]


class FeatureExtractor(ABC):
    """
    Abstract interface for models that turn batches of point clouds into
    feature vectors and class predictions.
    """

    @property
    @abstractmethod
    def supports_predictions(self) -> bool:
        """Flag reported by the extractor for whether it yields class predictions."""

    @property
    @abstractmethod
    def feature_dim(self) -> int:
        """Width of each feature vector returned by ``features_and_preds``."""

    @property
    @abstractmethod
    def num_classes(self) -> int:
        """Width of each prediction row returned by ``features_and_preds``."""

    @abstractmethod
    def features_and_preds(self, streamer: NpzStreamer) -> Tuple[np.ndarray, np.ndarray]:
        """
        For a stream of point cloud batches, compute feature vectors and class
        predictions.

        :param streamer: a source of point cloud batches. Typically, arr_0
                         will contain the XYZ coordinates.
        :return: a tuple (features, predictions)
                 - features: a [B x feature_dim] array of feature vectors.
                 - predictions: a [B x num_classes] array of probabilities.
        """


class PointNetClassifier(FeatureExtractor):
    """
    PointNet++ classifier (loaded from the "pointnet" checkpoint) used as a
    point-cloud feature extractor: 256-d features and 40-way predictions.
    """

    def __init__(
        self,
        devices: List[Union[str, torch.device]],
        device_batch_size: int = 64,
        cache_dir: Optional[str] = None,
    ):
        """
        :param devices: devices available for inference. Only the first entry
                        is used (no multi-device fan-out). If the list is
                        empty, CUDA is used when available, otherwise CPU.
        :param device_batch_size: per-device batch size. Stored for
                                  compatibility; not used by
                                  ``features_and_preds``, which consumes
                                  already-batched input.
        :param cache_dir: optional cache directory forwarded to
                          ``load_checkpoint``.
        """
        # Load weights on CPU first; the model is moved to the target device below.
        state_dict = load_checkpoint("pointnet", device=torch.device("cpu"), cache_dir=cache_dir)[
            "model_state_dict"
        ]

        self.device_batch_size = device_batch_size
        self.devices = devices
        # Previously hard-coded to 'cuda', which crashed on CPU-only hosts even
        # when the caller passed ["cpu"].
        if devices:
            self.device = torch.device(devices[0])
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model = get_model(num_class=self.num_classes, normal_channel=False, width_mult=2)
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        self.model = model

    @property
    def supports_predictions(self) -> bool:
        """This extractor returns class predictions alongside features."""
        return True

    @property
    def feature_dim(self) -> int:
        """Width of the feature vectors produced by the model."""
        return 256

    @property
    def num_classes(self) -> int:
        """Number of classes the checkpoint was trained on."""
        return 40

    def features_and_preds(self, streamer) -> Tuple[np.ndarray, np.ndarray]:
        """
        Run the classifier over every batch yielded by ``streamer``.

        :param streamer: an iterable of point-cloud batches shaped [B x L x 3]
                         (torch tensors or array-likes such as numpy arrays).
        :return: a tuple (features, predictions), concatenated over all batches:
                 - features: a [total_B x feature_dim] array.
                 - predictions: a [total_B x num_classes] array computed as
                   exp(logits). This presumably assumes the model emits
                   log-probabilities; confirm against get_model.
                 Both arrays have zero rows if the stream yields nothing.
        """
        output_features = []
        output_predictions = []

        # Inference only: disable autograd once for the whole stream.
        with torch.no_grad():
            for batch in tqdm(streamer):  # type: ignore
                # as_tensor leaves tensors as-is and also accepts numpy arrays.
                batch = torch.as_tensor(batch).to(dtype=torch.float32, device=self.device)
                batch = batch.permute(0, 2, 1)  # B L 3 -> B 3 L (channels-first)
                logits, _, features = self.model(batch, features=True)
                output_features.append(features.cpu().numpy())
                output_predictions.append(logits.exp().cpu().numpy())

        # np.concatenate raises on an empty list, so handle an empty stream explicitly.
        if not output_features:
            return (
                np.zeros((0, self.feature_dim), dtype=np.float32),
                np.zeros((0, self.num_classes), dtype=np.float32),
            )
        return np.concatenate(output_features, axis=0), np.concatenate(output_predictions, axis=0)


def normalize_point_clouds(pc: np.ndarray) -> np.ndarray:
    """
    Center each point cloud at the origin and scale it to fit the unit ball.

    :param pc: a batch of point clouds. Expects [B x N x D] (points on axis 1,
               coordinates on the last axis). Typically D == 3.
    :return: an array of the same shape. Each cloud has zero mean, and its
             farthest point lies at distance 1 from the origin. Degenerate
             clouds, whose points all coincide, come back as all zeros
             instead of NaN.
    """
    centroids = np.mean(pc, axis=1, keepdims=True)
    pc = pc - centroids
    # Per-cloud radius: the largest point norm, kept as [B x 1 x 1] for broadcasting.
    m = np.max(np.sqrt(np.sum(pc**2, axis=-1, keepdims=True)), axis=1, keepdims=True)
    # A zero radius means the centered cloud is already all zeros; avoid 0/0 -> NaN.
    m = np.where(m > 0, m, 1.0)
    return pc / m