File size: 15,718 Bytes
2252f3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
import torch
import torch.nn.functional as tfunc
import torch_scatter

def prepend_dummies(
        vertices:torch.Tensor, #V,D
        faces:torch.Tensor, #F,3 long
    )->tuple[torch.Tensor,torch.Tensor]:
    """prepend dummy elements to vertices and faces to enable "masked" scatter operations

    A NaN vertex is inserted at index 0 and an all-zero face at index 0. All
    original face indices are shifted up by one, so they keep referring to the
    same vertices and index 0 can act as the "unused" marker elsewhere.

    Args:
        vertices: V,D float tensor.
        faces: F,3 integer tensor of vertex indices.

    Returns:
        (vertices, faces) of shapes (V+1,D) and (F+1,3). The vertex dtype and
        device are preserved.
    """
    V,D = vertices.shape
    # build the NaN row in the input dtype so the concatenation does not
    # silently promote the vertices (e.g. float16 -> float32)
    dummy_vertex = torch.full((1,D),fill_value=torch.nan,dtype=vertices.dtype,device=vertices.device)
    vertices = torch.concat((dummy_vertex,vertices),dim=0)
    dummy_face = torch.zeros((1,3),dtype=torch.long,device=faces.device)
    faces = torch.concat((dummy_face,faces+1),dim=0)
    return vertices,faces

def remove_dummies(
        vertices:torch.Tensor, #V,D - first vertex all nan and unreferenced
        faces:torch.Tensor, #F,3 long - first face all zeros
    )->tuple[torch.Tensor,torch.Tensor]:
    """remove dummy elements added with prepend_dummies()

    Drops the leading vertex and face and shifts the face indices back down
    by one so they address the shortened vertex tensor.
    """
    real_vertices = vertices[1:]
    real_faces = faces[1:] - 1
    return real_vertices,real_faces


def calc_edges(
        faces: torch.Tensor,  # F,3 long - first face may be dummy with all zeros
        with_edge_to_face: bool = False
    ) -> tuple[torch.Tensor, ...]:
    """
    build the unique edge list of a triangle mesh

    returns tuple of
    - edges E,2 long, 0 for unused, lower vertex index first, sorted
    - face_to_edge F,3 long, side k of a face is the edge corner k -> corner k+1
    - (optional) edge_to_face shape=E,[left,right],[face,side]
      row 0 is zeroed (the (0,0) edge when the dummy face is present)

    o-<-----e1     e0,e1...edge, e0<e1
    |      /A      L,R....left and right face
    |  L /  |      both triangles ordered counter clockwise
    |  / R  |      normals pointing out of screen
    V/      |      
    e0---->-o     
    """
    num_faces = faces.shape[0]
    dev = faces.device

    # directed edge k of each face runs from corner k to corner k+1 (mod 3)
    directed = torch.stack((faces,faces.roll(-1,1)),dim=-1).reshape(num_faces*3,2) #F*3,2
    lo = torch.minimum(directed[:,0],directed[:,1])
    hi = torch.maximum(directed[:,0],directed[:,1])
    undirected = torch.stack((lo,hi),dim=-1) #F*3,2 lower vertex index first

    edges,inverse = torch.unique(input=undirected,sorted=True,return_inverse=True,dim=0) #(E,2),(F*3)
    face_to_edge = inverse.reshape(num_faces,3) #F,3
    if not with_edge_to_face:
        return edges, face_to_edge

    num_edges = edges.shape[0]
    # a directed edge running from the higher to the lower vertex belongs to the right face
    slot = 2*inverse + (directed[:,0] > directed[:,1]).long() #F*3, row in the flattened (E*2) table
    face_ids = torch.arange(num_faces,device=dev).repeat_interleave(3) #F*3
    side_ids = torch.arange(3,device=dev).repeat(num_faces) #F*3
    src = torch.stack((face_ids,side_ids),dim=-1) #F*3,2 (face,side)
    edge_to_face = torch.zeros((num_edges,2,2),dtype=torch.long,device=dev) #E,LR=2,S=2
    edge_to_face.view(num_edges*2,2).scatter_(dim=0,index=slot[:,None].expand(-1,2),src=src)
    edge_to_face[0] = 0
    return edges, face_to_edge, edge_to_face

def calc_edge_length(
        vertices:torch.Tensor, #V,3 first may be dummy
        edges:torch.Tensor, #E,2 long, lower vertex index first, (0,0) for unused
        )->torch.Tensor: #E
    """euclidean length of every edge

    Edges touching a NaN dummy vertex get length NaN.
    """
    start = vertices[edges[:,0]] #E,3
    end = vertices[edges[:,1]] #E,3
    return torch.norm(start-end,p=2,dim=-1)

def calc_face_normals(
        vertices:torch.Tensor, #V,3 first vertex may be unreferenced
        faces:torch.Tensor, #F,3 long, first face may be all zero
        normalize:bool=False,
        )->torch.Tensor: #F,3
    """
    per-face normal as the cross product (c1-c0) x (c2-c0)

    Unnormalized, its length is twice the triangle area. With normalize=True
    it is scaled to unit length (eps=1e-6 guards degenerate faces).

         n
         |
         c0     corners ordered counterclockwise when
        / \     looking onto surface (in neg normal direction)
      c1---c2
    """
    corners = vertices[faces] #F,C=3,3
    origin = corners[:,0]
    normals = torch.cross(corners[:,1]-origin, corners[:,2]-origin, dim=1) #F,3
    if not normalize:
        return normals
    return tfunc.normalize(normals, eps=1e-6, dim=1)

def calc_vertex_normals(
        vertices:torch.Tensor, #V,3 first vertex may be unreferenced
        faces:torch.Tensor, #F,3 long, first face may be all zero
        face_normals:torch.Tensor=None, #F,3, not normalized
        )->torch.Tensor: #V,3
    """unit vertex normals: sum of the normals of all incident faces, normalized

    If face_normals is not given, it is computed with calc_face_normals().
    Unreferenced vertices end up with a zero vector (normalize uses eps=1e-6).
    """
    if face_normals is None:
        face_normals = calc_face_normals(vertices,faces)

    num_faces = faces.shape[0]
    # scatter_add_ writes src[f,c,k] into accum[faces[f,c],c,k], so corner c of
    # every face lands in slot c; summing the slots gives the per-vertex total
    index = faces.unsqueeze(-1).expand(num_faces,3,3) #F,C=3,3
    src = face_normals.unsqueeze(1).expand(num_faces,3,3) #F,C=3,3
    accum = torch.zeros((vertices.shape[0],3,3),dtype=vertices.dtype,device=vertices.device) #V,C=3,3
    accum.scatter_add_(dim=0,index=index,src=src)
    return tfunc.normalize(accum.sum(dim=1), eps=1e-6, dim=1)

def calc_face_ref_normals(
        faces:torch.Tensor, #F,3 long, 0 for unused
        vertex_normals:torch.Tensor, #V,3 first unused
        normalize:bool=False,
        )->torch.Tensor: #F,3
    """calculate reference normals for face flip detection

    The reference normal of a face is the sum of the normals of its three
    corner vertices, optionally scaled to unit length (eps=1e-6).
    """
    corner_sum = vertex_normals[faces].sum(dim=1) #F,3
    return tfunc.normalize(corner_sum, eps=1e-6, dim=1) if normalize else corner_sum

def pack(
        vertices:torch.Tensor, #V,3 first unused and nan
        faces:torch.Tensor, #F,3 long, 0 for unused
        )->tuple[torch.Tensor,torch.Tensor]: #(vertices,faces), keeps first vertex unused
    """removes unused elements in vertices and faces

    A face is unused when its first index is 0 (face 0, the dummy, is always
    kept). A vertex is unused when no remaining face references it (vertex 0
    is always kept). Face indices are remapped to the compacted vertex tensor.
    """
    num_verts = vertices.shape[0]

    # drop unused faces, keep the dummy
    keep_face = faces[:,0] != 0
    keep_face[0] = True
    faces = faces[keep_face] #sync

    # mark every vertex referenced by a remaining face, keep the dummy
    keep_vert = torch.zeros(num_verts,dtype=torch.bool,device=vertices.device)
    keep_vert[faces.reshape(-1)] = True
    keep_vert[0] = True
    vertices = vertices[keep_vert] #sync

    # old vertex index -> new vertex index
    remap = torch.zeros(num_verts,dtype=torch.long,device=vertices.device)
    remap[keep_vert] = torch.arange(0,vertices.shape[0],device=vertices.device)
    return vertices,remap[faces]

def split_edges(
        vertices:torch.Tensor, #V,3 first unused
        faces:torch.Tensor, #F,3 long, 0 for unused
        edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
        face_to_edge:torch.Tensor, #F,3 long 0 for unused
        splits, #E bool
        pack_faces:bool=True,
        )->tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)
    """split the selected edges at their midpoints and retriangulate the affected faces

    New midpoint vertices are appended after the existing ones (indices V..V+S-1).
    Each face is replaced by its shrunk face S plus up to three corner faces N0..N2.
    With pack_faces, all-zero face rows are dropped (face 0 is always kept);
    otherwise 4F face rows are returned. If nothing is split, the inputs are
    returned unchanged.

      c2                    c2               c...corners = faces
       . .                   . .             s...side_vert, 0 means no split
       .   .                 .N2 .           S...shrunk_face
       .     .               .     .         Ni...new_faces
      s2      s1           s2|c2...s1|c1
       .        .            .     .  .
       .          .          . S .      .
       .            .        . .     N1    .
      c0...(s0=0)....c1    s0|c0...........c1

      S  = [s0|c0, s1|c1, s2|c2]
      Ni = split[i] * [ci, si, S[i-1]]
    """
    num_verts = vertices.shape[0]
    num_faces = faces.shape[0]
    num_splits = splits.sum().item() #sync

    if num_splits == 0:
        return vertices,faces

    # index of the midpoint vertex for every edge, 0 where the edge is not split
    new_vert_of_edge = torch.zeros_like(splits, dtype=torch.long) #E
    new_vert_of_edge[splits] = torch.arange(num_verts,num_verts+num_splits,dtype=torch.long,device=vertices.device) #sync
    side_vert = new_vert_of_edge[face_to_edge] #F,3 long, 0 for no split

    # one midpoint per split edge
    midpoints = vertices[edges[splits]].mean(dim=1) #S,3
    vertices = torch.concat((vertices,midpoints),dim=0)

    has_split = side_vert != 0 #F,3
    shrunk_faces = torch.where(has_split,side_vert,faces) #F,3
    prev_shrunk = shrunk_faces.roll(1,dims=-1) #F,3 S[i-1]
    corner_faces = torch.stack((faces,side_vert,prev_shrunk),dim=-1) #F,N=3,C=3
    corner_faces = corner_faces * has_split[:,:,None] #zero rows whose side is not split
    faces = torch.concat((shrunk_faces,corner_faces.reshape(num_faces*3,3))) #4F,3

    if pack_faces:
        keep = faces[:,0] != 0
        keep[0] = True
        faces = faces[keep] #F',3 sync
    return vertices,faces

def collapse_edges(
        vertices:torch.Tensor, #V,3 first unused
        faces:torch.Tensor, #F,3 long 0 for unused
        edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
        priorities:torch.Tensor, #E float
        stable:bool=False, #only for unit testing
        )->tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)
    """collapse a spaced-out subset of edges with positive priority into their midpoints

    Each collapsed edge (e0,e1) is merged into e0: vertices[e0] is overwritten
    with the edge midpoint and face references to e1 are redirected to e0.
    Faces that become degenerate (two equal corners) are set to all zeros.

    NOTE: ``vertices`` is modified in place by the midpoint assignment; the
    returned ``faces`` is a new tensor. e1 stays in ``vertices`` but should no
    longer be referenced by the remapped faces. Presumably callers run pack()
    afterwards to drop it; confirm.

    Args:
        priorities: per-edge priority. Only edges with priority > 0 are
            considered, and higher priority wins among nearby candidates.
        stable: use a stable sort for the priority ranking (deterministic ties).

    Returns:
        (vertices, faces)
    """
    V = vertices.shape[0]
    
    # check spacing:
    # rank edges by priority (larger priority -> larger rank), then propagate the
    # max rank edge->vertex->edge three times. An edge stays a candidate only if
    # no edge reached by that propagation has a higher rank, which keeps the
    # selected collapses apart from each other.
    # NOTE(review): relies on torch_scatter.scatter_max combining with the
    # existing values of `out` (vert_rank keeps earlier rounds); confirm.
    _,order = priorities.sort(stable=stable) #E
    rank = torch.zeros_like(order)
    rank[order] = torch.arange(0,len(rank),device=rank.device)
    vert_rank = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
    edge_rank = rank #E
    for i in range(3):
        torch_scatter.scatter_max(src=edge_rank[:,None].expand(-1,2).reshape(-1),index=edges.reshape(-1),dim=0,out=vert_rank)
        edge_rank,_ = vert_rank[edges].max(dim=-1) #E
    candidates = edges[(edge_rank==rank).logical_and_(priorities>0)] #E',2

    # check connectivity:
    # count, for each candidate's end vertex, two-step paths to start vertices
    # through vertices that are not themselves candidate endpoints. This is
    # roughly the number of vertices adjacent to both start and end. More than
    # two presumably means the collapse would break manifoldness (link
    # condition); confirm.
    vert_connections = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
    vert_connections[candidates[:,0]] = 1 #start
    edge_connections = vert_connections[edges].sum(dim=-1) #E, edge connected to start
    vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1))# one edge from start
    vert_connections[candidates] = 0 #clear start and end
    edge_connections = vert_connections[edges].sum(dim=-1) #E, one or two edges from start
    vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) #one or two edges from start
    collapses = candidates[vert_connections[candidates[:,1]] <= 2] # E" not more than two connections between start and end

    # mean vertices (in place): vertices[collapses] is E",2,3, so dim=1 averages the two endpoints
    vertices[collapses[:,0]] = vertices[collapses].mean(dim=1) #E",3
    
    # update faces: map every end vertex onto its start vertex, then zero degenerate faces
    dest = torch.arange(0,V,dtype=torch.long,device=vertices.device) #V
    dest[collapses[:,1]] = dest[collapses[:,0]]
    faces = dest[faces] #F,3 TODO optimize?
    c0,c1,c2 = faces.unbind(dim=-1)
    collapsed = (c0==c1).logical_or_(c1==c2).logical_or_(c0==c2)
    faces[collapsed] = 0

    return vertices,faces

def calc_face_collapses(
        vertices:torch.Tensor, #V,3 first unused
        faces:torch.Tensor, #F,3 long, 0 for unused
        edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
        face_to_edge:torch.Tensor, #F,3 long 0 for unused
        edge_length:torch.Tensor, #E
        face_normals:torch.Tensor, #F,3
        vertex_normals:torch.Tensor, #V,3 first unused
        min_edge_length:torch.Tensor=None, #V
        area_ratio = 0.5, #collapse if area < min_edge_length**2 * area_ratio
        shortest_probability = 0.8
        )->torch.Tensor: #E edges to collapse
    """select edges to collapse, one per "bad" face

    A face is bad if its normal points against the sum of its corner vertex
    normals (a flipped face). When min_edge_length is given, a face is also
    bad if it is smaller than mean(min_edge_length at corners)**2 * area_ratio;
    the dummy face 0 is then explicitly excluded.

    For every face one of its edges is chosen: the shortest with probability
    of roughly shortest_probability (otherwise the longest or the middle one),
    or always the shortest when shortest_probability >= 1. The result is a
    bool E mask of edges chosen by at least one bad face.
    """
    num_edges = edges.shape[0]
    num_faces = faces.shape[0]

    # flipped faces
    ref_normals = calc_face_ref_normals(faces,vertex_normals,normalize=False) #F,3
    alignment = (face_normals*ref_normals).sum(dim=-1) #F
    face_collapses = alignment < 0 #F bool

    # too small faces; assumes face_normals are unnormalized cross products,
    # whose length is twice the face area
    if min_edge_length is not None:
        corner_min_length = min_edge_length[faces] #F,3
        min_area = corner_min_length.mean(dim=-1)**2 * area_ratio #F
        face_collapses.logical_or_(face_normals.norm(dim=-1) < min_area*2)
        face_collapses[0] = False

    face_length = edge_length[face_to_edge] #F,3
    if shortest_probability<1:
        # values 0,1 pick the longest/middle edge, everything >= 2 (clamped) the shortest
        randlim = round(2/(1-shortest_probability))
        rand_ind = torch.randint(0,randlim,size=(num_faces,),device=faces.device).clamp_max_(2) #F
        longest_first = torch.argsort(face_length,dim=-1,descending=True) #F,3
        local_ind = longest_first.gather(dim=-1,index=rand_ind[:,None]) #F,1
    else:
        local_ind = torch.argmin(face_length,dim=-1)[:,None] #F,1

    edge_ind = face_to_edge.gather(dim=1,index=local_ind).squeeze(1) #F global edge index
    hits = torch.zeros(num_edges,dtype=torch.long,device=vertices.device) #E
    hits.scatter_add_(dim=0,index=edge_ind,src=face_collapses.long())
    return hits.bool()

def flip_edges(
        vertices:torch.Tensor, #V,3 first unused
        faces:torch.Tensor, #F,3 long, first must be 0, 0 for unused
        edges:torch.Tensor, #E,2 long, first must be 0, 0 for unused, lower vertex index first
        edge_to_face:torch.Tensor, #E,[left,right],[face,side]
        with_border:bool=True, #handle border edges (D=4 instead of D=6)
        with_normal_check:bool=True, #check face normal flips
        stable:bool=False, #only for unit testing
        ):
    """flip interior edges so that vertex degrees move towards their ideal value

    An interior edge (both adjacent faces exist) is a candidate when flipping it
    lowers sum((degree-target)**2) over its two endpoints and two opposite corners.
    With with_border, inside vertex degrees are reduced by 2 so that inside
    (target 6) and border (target 4) vertices share one target.
    Among candidates sharing any of these four vertices, only the one with the
    most negative loss change is flipped. With with_normal_check, flips that
    would turn a new face against the sum of the two old face normals are
    rejected.

    The left face becomes (e0,cr,cl) and the right face (e1,cl,cr).
    ``faces`` is modified in place and the function returns None; ``edges`` and
    ``edge_to_face`` are not updated.
    """
    V = vertices.shape[0]
    E = edges.shape[0]
    device=vertices.device
    vertex_degree = torch.zeros(V,dtype=torch.long,device=device) #V long
    vertex_degree.scatter_(dim=0,index=edges.reshape(E*2),value=1,reduce='add')
    neighbor_corner = (edge_to_face[:,:,1] + 2) % 3 #go from side to corner
    neighbors = faces[edge_to_face[:,:,0],neighbor_corner] #E,LR=2
    edge_is_inside = neighbors.all(dim=-1) #E

    if with_border:
        # inside vertices should have D=6, border edges D=4, so we subtract 2 for all inside vertices
        # need to use float for masks in order to use scatter(reduce='multiply')
        vertex_is_inside = torch.ones(V,2,dtype=torch.float32,device=vertices.device) #V,2 float
        src = edge_is_inside.type(torch.float32)[:,None].expand(E,2) #E,2 float
        vertex_is_inside.scatter_(dim=0,index=edges,src=src,reduce='multiply')
        vertex_is_inside = vertex_is_inside.prod(dim=-1,dtype=torch.long) #V long
        vertex_degree -= 2 * vertex_is_inside #V long

    neighbor_degrees = vertex_degree[neighbors] #E,LR=2
    edge_degrees = vertex_degree[edges] #E,2
    #
    # loss = Sum_over_affected_vertices((new_degree-6)**2)
    # loss_change = Sum_over_neighbor_vertices((degree+1-6)**2-(degree-6)**2)
    #                   + Sum_over_edge_vertices((degree-1-6)**2-(degree-6)**2)
    #             = 2 * (2 + Sum_over_neighbor_vertices(degree) - Sum_over_edge_vertices(degree))
    #
    loss_change = 2 + neighbor_degrees.sum(dim=-1) - edge_degrees.sum(dim=-1) #E
    candidates = torch.logical_and(loss_change<0, edge_is_inside) #E
    loss_change = loss_change[candidates] #E'
    if loss_change.shape[0]==0:
        return

    edges_neighbors = torch.concat((edges[candidates],neighbors[candidates]),dim=-1) #E',4
    # descending sort: the most negative loss change gets the highest rank
    _,order = loss_change.sort(descending=True, stable=stable) #E'
    rank = torch.zeros_like(order)
    rank[order] = torch.arange(0,len(rank),device=rank.device)
    vertex_rank = torch.zeros((V,4),dtype=torch.long,device=device) #V,4
    torch_scatter.scatter_max(src=rank[:,None].expand(-1,4),index=edges_neighbors,dim=0,out=vertex_rank)
    vertex_rank,_ = vertex_rank.max(dim=-1) #V
    neighborhood_rank,_ = vertex_rank[edges_neighbors].max(dim=-1) #E'
    flip = rank==neighborhood_rank #E'

    if with_normal_check:
        #  cl-<-----e1     e0,e1...edge, e0<e1
        #   |      /A      L,R....left and right face
        #   |  L /  |      both triangles ordered counter clockwise
        #   |  / R  |      normals pointing out of screen
        #   V/      |      
        #   e0---->-cr    
        v = vertices[edges_neighbors] #E',4,3
        v = v - v[:,0:1] #make relative to e0 
        e1 = v[:,1]
        cl = v[:,2]
        cr = v[:,3]
        # dim=-1 must be explicit: without it torch.cross uses the first dimension
        # of size 3, which is dim 0 whenever exactly three candidates remain
        n = torch.cross(e1,cl,dim=-1) + torch.cross(cr,e1,dim=-1) #sum of old normal vectors 
        flip.logical_and_(torch.sum(n*torch.cross(cr,cl,dim=-1),dim=-1)>0) #first new face
        flip.logical_and_(torch.sum(n*torch.cross(cl-e1,cr-e1,dim=-1),dim=-1)>0) #second new face

    flip_edges_neighbors = edges_neighbors[flip] #E",4
    flip_edge_to_face = edge_to_face[candidates,:,0][flip] #E",2
    flip_faces = flip_edges_neighbors[:,[[0,3,2],[1,2,3]]] #E",2,3
    faces.scatter_(dim=0,index=flip_edge_to_face.reshape(-1,1).expand(-1,3),src=flip_faces.reshape(-1,3))