Maikou committed
Commit d1dec7c
1 Parent(s): 355f96f

Delete michelangelo/models/tsal/sal_transformer.py

michelangelo/models/tsal/sal_transformer.py DELETED
@@ -1,286 +0,0 @@
- # -*- coding: utf-8 -*-
-
- import torch
- import torch.nn as nn
- from torch_cluster import fps
- from typing import Optional
- import math
-
- from michelangelo.models.modules import checkpoint
- from michelangelo.models.modules.embedder import FourierEmbedder
- from michelangelo.models.modules.distributions import DiagonalGaussianDistribution
- from michelangelo.models.modules.transformer_blocks import (
-     ResidualCrossAttentionBlock,
-     Transformer
- )
-
- from .tsal_base import ShapeAsLatentModule
-
-
- class CrossAttentionEncoder(nn.Module):
-
-     def __init__(self, *,
-                  device: Optional[torch.device],
-                  dtype: Optional[torch.dtype],
-                  num_latents: int,
-                  fourier_embedder: FourierEmbedder,
-                  point_feats: int,
-                  width: int,
-                  heads: int,
-                  init_scale: float = 0.25,
-                  qkv_bias: bool = True,
-                  use_ln_post: bool = False,
-                  use_checkpoint: bool = False):
-
-         super().__init__()
-
-         self.use_checkpoint = use_checkpoint
-         self.num_latents = num_latents
-         self.fourier_embedder = fourier_embedder
-
-         self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype)
-         self.cross_attn_encoder = ResidualCrossAttentionBlock(
-             device=device,
-             dtype=dtype,
-             width=width,
-             heads=heads,
-             init_scale=init_scale,
-             qkv_bias=qkv_bias
-         )
-         if use_ln_post:
-             self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device)
-         else:
-             self.ln_post = None
-
-     def _forward(self, pc, feats):
-         """
-
-         Args:
-             pc (torch.FloatTensor): [B, N, 3]
-             feats (torch.FloatTensor or None): [B, N, C]
-
-         Returns:
-             latents (torch.FloatTensor): [B, num_latents, width]
-             center_pos (torch.FloatTensor): [B, num_latents, 3]
-         """
-
-         B, N, _ = pc.shape
-         batch = torch.arange(B).to(pc.device)
-         batch = torch.repeat_interleave(batch, N)
-
-         data = self.fourier_embedder(pc)
-         if feats is not None:
-             data = torch.cat([data, feats], dim=-1)
-         data = self.input_proj(data)
-
-         # Farthest point sampling selects num_latents well-spread points;
-         # their projected features become the cross-attention queries.
-         ratio = self.num_latents / N
-         flatten_pos = pc.view(B * N, -1)  # [B * N, 3]
-         flatten_data = data.view(B * N, -1)  # [B * N, C]
-         idx = fps(flatten_pos, batch, ratio=ratio)
-         center_pos = flatten_pos[idx].view(B, self.num_latents, -1)
-         query = flatten_data[idx].view(B, self.num_latents, -1)
-
-         latents = self.cross_attn_encoder(query, data)
-
-         if self.ln_post is not None:
-             latents = self.ln_post(latents)
-
-         return latents, center_pos
-
-     def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
-         """
-
-         Args:
-             pc (torch.FloatTensor): [B, N, 3]
-             feats (torch.FloatTensor or None): [B, N, C]
-
-         Returns:
-             latents (torch.FloatTensor): [B, num_latents, width]
-             center_pos (torch.FloatTensor): [B, num_latents, 3]
-         """
-
-         return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint)
-
-
- class CrossAttentionDecoder(nn.Module):
-
-     def __init__(self, *,
-                  device: Optional[torch.device],
-                  dtype: Optional[torch.dtype],
-                  num_latents: int,
-                  out_channels: int,
-                  fourier_embedder: FourierEmbedder,
-                  width: int,
-                  heads: int,
-                  init_scale: float = 0.25,
-                  qkv_bias: bool = True,
-                  use_checkpoint: bool = False):
-
-         super().__init__()
-
-         self.use_checkpoint = use_checkpoint
-         self.fourier_embedder = fourier_embedder
-
-         self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
-
-         self.cross_attn_decoder = ResidualCrossAttentionBlock(
-             device=device,
-             dtype=dtype,
-             n_data=num_latents,
-             width=width,
-             heads=heads,
-             init_scale=init_scale,
-             qkv_bias=qkv_bias
-         )
-
-         self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
-         self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)
-
-     def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
-         queries = self.query_proj(self.fourier_embedder(queries))
-         x = self.cross_attn_decoder(queries, latents)
-         x = self.ln_post(x)
-         x = self.output_proj(x)
-         return x
-
-     def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
-         return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint)
-
-
- class ShapeAsLatentTransformer(ShapeAsLatentModule):
-     def __init__(self, *,
-                  device: Optional[torch.device],
-                  dtype: Optional[torch.dtype],
-                  num_latents: int,
-                  point_feats: int = 0,
-                  embed_dim: int = 0,
-                  num_freqs: int = 8,
-                  include_pi: bool = True,
-                  width: int,
-                  layers: int,
-                  heads: int,
-                  init_scale: float = 0.25,
-                  qkv_bias: bool = True,
-                  use_ln_post: bool = False,
-                  use_checkpoint: bool = False):
-
-         super().__init__()
-
-         self.use_checkpoint = use_checkpoint
-
-         self.num_latents = num_latents
-         self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
-
-         init_scale = init_scale * math.sqrt(1.0 / width)
-         self.encoder = CrossAttentionEncoder(
-             device=device,
-             dtype=dtype,
-             fourier_embedder=self.fourier_embedder,
-             num_latents=num_latents,
-             point_feats=point_feats,
-             width=width,
-             heads=heads,
-             init_scale=init_scale,
-             qkv_bias=qkv_bias,
-             use_ln_post=use_ln_post,
-             use_checkpoint=use_checkpoint
-         )
-
-         self.embed_dim = embed_dim
-         if embed_dim > 0:
-             # VAE embed
-             self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype)
-             self.post_kl = nn.Linear(embed_dim, width, device=device, dtype=dtype)
-             self.latent_shape = (num_latents, embed_dim)
-         else:
-             self.latent_shape = (num_latents, width)
-
-         self.transformer = Transformer(
-             device=device,
-             dtype=dtype,
-             n_ctx=num_latents,
-             width=width,
-             layers=layers,
-             heads=heads,
-             init_scale=init_scale,
-             qkv_bias=qkv_bias,
-             use_checkpoint=use_checkpoint
-         )
-
-         # geometry decoder
-         self.geo_decoder = CrossAttentionDecoder(
-             device=device,
-             dtype=dtype,
-             fourier_embedder=self.fourier_embedder,
-             out_channels=1,
-             num_latents=num_latents,
-             width=width,
-             heads=heads,
-             init_scale=init_scale,
-             qkv_bias=qkv_bias,
-             use_checkpoint=use_checkpoint
-         )
-
-     def encode(self,
-                pc: torch.FloatTensor,
-                feats: Optional[torch.FloatTensor] = None,
-                sample_posterior: bool = True):
-         """
-
-         Args:
-             pc (torch.FloatTensor): [B, N, 3]
-             feats (torch.FloatTensor or None): [B, N, C]
-             sample_posterior (bool):
-
-         Returns:
-             latents (torch.FloatTensor): [B, num_latents, embed_dim or width]
-             center_pos (torch.FloatTensor): [B, num_latents, 3]
-             posterior (DiagonalGaussianDistribution or None):
-         """
-
-         latents, center_pos = self.encoder(pc, feats)
-
-         posterior = None
-         if self.embed_dim > 0:
-             moments = self.pre_kl(latents)
-             posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
-
-             if sample_posterior:
-                 latents = posterior.sample()
-             else:
-                 latents = posterior.mode()
-
-         return latents, center_pos, posterior
-
-     def decode(self, latents: torch.FloatTensor):
-         # Assumes embed_dim > 0 (post_kl only exists in the VAE configuration).
-         latents = self.post_kl(latents)
-         return self.transformer(latents)
-
-     def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
-         logits = self.geo_decoder(queries, latents).squeeze(-1)
-         return logits
-
-     def forward(self,
-                 pc: torch.FloatTensor,
-                 feats: Optional[torch.FloatTensor],
-                 volume_queries: torch.FloatTensor,
-                 sample_posterior: bool = True):
-         """
-
-         Args:
-             pc (torch.FloatTensor): [B, N, 3]
-             feats (torch.FloatTensor or None): [B, N, C]
-             volume_queries (torch.FloatTensor): [B, P, 3]
-             sample_posterior (bool):
-
-         Returns:
-             logits (torch.FloatTensor): [B, P]
-             center_pos (torch.FloatTensor): [B, M, 3]
-             posterior (DiagonalGaussianDistribution or None)
-         """
-
-         latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)
-
-         latents = self.decode(latents)
-         logits = self.query_geometry(volume_queries, latents)
-
-         return logits, center_pos, posterior
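
For readers without torch_cluster installed, the one non-obvious dependency in the deleted file is the fps call in CrossAttentionEncoder._forward. Below is a minimal pure-PyTorch sketch of greedy farthest point sampling; naive_fps and the 4096/256 sizes are illustrative stand-ins, not code from this repository, and torch_cluster.fps performs the same selection batched (via the batch vector and ratio argument) in optimized C++/CUDA.

import torch

def naive_fps(pos: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Greedy farthest point sampling over one cloud [N, 3] -> indices [num_samples]."""
    N = pos.shape[0]
    idx = torch.empty(num_samples, dtype=torch.long)
    dist = torch.full((N,), float("inf"))  # distance from each point to its nearest pick
    farthest = torch.tensor(0)             # arbitrary starting point
    for i in range(num_samples):
        idx[i] = farthest
        d = ((pos - pos[farthest]) ** 2).sum(dim=-1)
        dist = torch.minimum(dist, d)      # refresh nearest-pick distances
        farthest = dist.argmax()           # next pick: point farthest from all picks
    return idx

pc = torch.randn(4096, 3)                  # one synthetic point cloud
centers = pc[naive_fps(pc, 256)]           # [256, 3], spread evenly over the cloud

This even spatial coverage is what FPS buys the encoder over simply taking the first num_latents input points: the latent queries cover the whole shape regardless of how unevenly the input cloud is sampled.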