import torch
import numpy as np
import logging
from copy import deepcopy
from .utils.libkdtree import KDTree

logger_py = logging.getLogger(__name__)


def compute_iou(occ1, occ2):
    ''' Computes the Intersection over Union (IoU) value for two sets of
    occupancy values.
    Args:
        occ1 (tensor): first set of occupancy values
        occ2 (tensor): second set of occupancy values
    '''
    occ1 = np.asarray(occ1)
    occ2 = np.asarray(occ2)

    # Put all data in second dimension
    # Also works for 1-dimensional data
    if occ1.ndim >= 2:
        occ1 = occ1.reshape(occ1.shape[0], -1)
    if occ2.ndim >= 2:
        occ2 = occ2.reshape(occ2.shape[0], -1)

    # Convert to boolean values
    occ1 = (occ1 >= 0.5)
    occ2 = (occ2 >= 0.5)

    # Compute IOU
    area_union = (occ1 | occ2).astype(np.float32).sum(axis=-1)
    area_intersect = (occ1 & occ2).astype(np.float32).sum(axis=-1)

    iou = (area_intersect / area_union)

    return iou
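
# Usage sketch (illustrative shapes and values):
#     occ_pred = np.random.rand(4, 1000)          # predicted occupancy probabilities
#     occ_gt = np.random.rand(4, 1000) > 0.5      # ground-truth occupancies
#     iou = compute_iou(occ_pred, occ_gt)         # -> array of shape (4,)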


def rgb2gray(rgb):
    ''' Converts RGB images to grayscale.

    Args:
        rgb (tensor): RGB images of size B x H x W x 3
    '''
    r, g, b = rgb[:, :, :, 0], rgb[:, :, :, 1], rgb[:, :, :, 2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b

    return gray


def sample_patch_points(
    batch_size, n_points, patch_size=1, image_resolution=(128, 128), continuous=True
):
    ''' Returns sampled points in the range [-1, 1].

    Args:
        batch_size (int): required batch size
        n_points (int): number of points to sample
        patch_size (int): size of patch; if > 1, patches of size patch_size
            are sampled instead of individual points
        image_resolution (tuple): image resolution (required for calculating
            the pixel distances)
        continuous (bool): whether to sample continuously or only on pixel
            locations
    '''
    assert (patch_size > 0)
    # Calculate step size for [-1, 1] that is equivalent to a pixel in
    # original resolution
    h_step = 1. / image_resolution[0]
    w_step = 1. / image_resolution[1]
    # Get number of patches
    patch_size_squared = patch_size**2
    n_patches = int(n_points / patch_size_squared)
    if continuous:
        p = torch.rand(batch_size, n_patches, 2)    # [0, 1]
    else:
        px = torch.randint(0, image_resolution[1],
                           size=(batch_size, n_patches, 1)).float() / (image_resolution[1] - 1)
        py = torch.randint(0, image_resolution[0],
                           size=(batch_size, n_patches, 1)).float() / (image_resolution[0] - 1)
        p = torch.cat([px, py], dim=-1)
    # Scale p to [0, (1 - (patch_size - 1) * step) ]
    p[:, :, 0] *= 1 - (patch_size - 1) * w_step
    p[:, :, 1] *= 1 - (patch_size - 1) * h_step

    # Add points
    patch_arange = torch.arange(patch_size)
    x_offset, y_offset = torch.meshgrid(patch_arange, patch_arange)
    patch_offsets = torch.stack([x_offset.reshape(-1), y_offset.reshape(-1)],
                                dim=1).view(1, 1, -1, 2).repeat(batch_size, n_patches, 1, 1).float()

    patch_offsets[:, :, :, 0] *= w_step
    patch_offsets[:, :, :, 1] *= h_step

    # Add patch_offsets to points
    p = p.view(batch_size, n_patches, 1, 2) + patch_offsets

    # Scale to [-1, 1]
    p = p * 2 - 1

    p = p.view(batch_size, -1, 2)

    amax, amin = p.max(), p.min()
    assert (amax <= 1. and amin >= -1.)

    return p
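
# Usage sketch (illustrative values): sample 256 points as 64 patches of 2x2 pixels.
#     p = sample_patch_points(batch_size=8, n_points=256, patch_size=2,
#                             image_resolution=(128, 128), continuous=True)
#     assert p.shape == (8, 256, 2) and p.min() >= -1. and p.max() <= 1.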


def get_proposal_points_in_unit_cube(ray0, ray_direction, padding=0.1, eps=1e-6, n_steps=40):
    ''' Returns n_steps equally spaced points inside the unit cube on the rays
    cast from ray0 with direction ray_direction.

    This function is used to get the ray marching points {p^ray_j} for a given
    camera position ray0 and
    a given ray direction ray_direction which goes from the camera_position to
    the pixel location.

    NOTE: The returned values d_proposal are the lengths of the ray:
        p^ray_j = ray0 + d_proposal_j * ray_direction

    Args:
        ray0 (tensor): Start positions of the rays
        ray_direction (tensor): Directions of rays
        padding (float): Padding which is applied to the unit cube
        eps (float): The epsilon value for numerical stability
        n_steps (int): number of steps
    '''
    batch_size, n_pts, _ = ray0.shape
    device = ray0.device

    p_intervals, d_intervals, mask_inside_cube = \
        check_ray_intersection_with_unit_cube(ray0, ray_direction, padding,
                                              eps)
    d_proposal = d_intervals[:, :, 0].unsqueeze(-1) + \
        torch.linspace(0, 1, steps=n_steps).to(device).view(1, 1, -1) * \
        (d_intervals[:, :, 1] - d_intervals[:, :, 0]).unsqueeze(-1)
    d_proposal = d_proposal.unsqueeze(-1)

    return d_proposal, mask_inside_cube
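
# Usage sketch (hypothetical rays): 40 proposal depths along rays that start
# outside the padded cube and point towards the origin.
#     ray0 = torch.tensor([[[0., 0., -2.]]]).repeat(1, 10, 1)   # B=1, 10 rays
#     ray_dir = torch.tensor([[[0., 0., 1.]]]).repeat(1, 10, 1)
#     d_proposal, mask = get_proposal_points_in_unit_cube(ray0, ray_dir, n_steps=40)
#     # d_proposal: (1, 10, 40, 1) ray lengths; mask: (1, 10) rays hitting the cube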


def check_ray_intersection_with_unit_cube(ray0, ray_direction, padding=0.1, eps=1e-6, scale=2.0):
    ''' Checks if rays ray0 + d * ray_direction intersect with the unit cube
    with the given padding.

    It returns the two intersection points as well as the sorted ray lengths d.

    Args:
        ray0 (tensor): Start positions of the rays
        ray_direction (tensor): Directions of rays
        padding (float): Padding which is applied to the unit cube
        eps (float): The epsilon value for numerical stability
        scale (float): cube size
    '''
    batch_size, n_pts, _ = ray0.shape
    device = ray0.device

    # calculate intersections with unit cube (< . , . >  is the dot product)
    # <n, x - p> = <n, ray0 + d * ray_direction - p_e> = 0
    # d = - <n, ray0 - p_e> / <n, ray_direction>

    # Get points on plane p_e
    p_distance = (scale * 0.5) + padding / 2
    p_e = torch.ones(batch_size, n_pts, 6).to(device) * p_distance
    p_e[:, :, 3:] *= -1.

    # Calculate the intersection points with given formula
    numerator = p_e - ray0.repeat(1, 1, 2)
    denominator = ray_direction.repeat(1, 1, 2)
    d_intersect = numerator / denominator
    p_intersect = ray0.unsqueeze(-2) + d_intersect.unsqueeze(-1) * \
        ray_direction.unsqueeze(-2)

    # Calculate mask where points intersect unit cube
    p_mask_inside_cube = (
        (p_intersect[:, :, :, 0] <= p_distance + eps) &
        (p_intersect[:, :, :, 1] <= p_distance + eps) &
        (p_intersect[:, :, :, 2] <= p_distance + eps) &
        (p_intersect[:, :, :, 0] >= -(p_distance + eps)) &
        (p_intersect[:, :, :, 1] >= -(p_distance + eps)) &
        (p_intersect[:, :, :, 2] >= -(p_distance + eps))
    ).cpu()

    # Correct rays are those which intersect the cube exactly twice
    mask_inside_cube = p_mask_inside_cube.sum(-1) == 2

    # Get interval values for p's which are valid
    p_intervals = p_intersect[mask_inside_cube][p_mask_inside_cube[mask_inside_cube]].view(-1, 2, 3)
    p_intervals_batch = torch.zeros(batch_size, n_pts, 2, 3).to(device)
    p_intervals_batch[mask_inside_cube] = p_intervals

    # Calculate ray lengths for the interval points
    d_intervals_batch = torch.zeros(batch_size, n_pts, 2).to(device)
    norm_ray = torch.norm(ray_direction[mask_inside_cube], dim=-1)
    d_intervals_batch[mask_inside_cube] = torch.stack(
        [
            torch.norm(p_intervals[:, 0] - ray0[mask_inside_cube], dim=-1) / norm_ray,
            torch.norm(p_intervals[:, 1] - ray0[mask_inside_cube], dim=-1) / norm_ray,
        ],
        dim=-1
    )

    # Sort the ray lengths
    d_intervals_batch, indices_sort = d_intervals_batch.sort()
    p_intervals_batch = p_intervals_batch[torch.arange(batch_size).view(-1, 1, 1),
                                          torch.arange(n_pts).view(1, -1, 1), indices_sort]

    return p_intervals_batch, d_intervals_batch, mask_inside_cube
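
# Usage sketch (hypothetical ray): a ray along +z starting at z=-2 enters and
# leaves the padded cube at the two returned interval points.
#     ray0 = torch.zeros(1, 1, 3); ray0[..., 2] = -2.
#     ray_dir = torch.zeros(1, 1, 3); ray_dir[..., 2] = 1.
#     p_int, d_int, mask = check_ray_intersection_with_unit_cube(ray0, ray_dir)
#     # p_int: (1, 1, 2, 3); d_int: (1, 1, 2) sorted ray lengths; mask: (1, 1)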


def intersect_camera_rays_with_unit_cube(
    pixels, camera_mat, world_mat, scale_mat, padding=0.1, eps=1e-6, use_ray_length_as_depth=True
):
    ''' Returns the intersection points of rays cast from the camera origin to
    pixel points p on the image plane.

    The function returns the intersection points as well as the depth values and
    a mask specifying which ray intersects the unit cube.

    Args:
        pixels (tensor): Pixel points on image plane (range [-1, 1])
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        scale_mat (tensor): scale matrix
        padding (float): Padding which is applied to the unit cube
        eps (float): The epsilon value for numerical stability
        use_ray_length_as_depth (bool): whether to return the ray length
            (instead of the camera-space z value) as the depth
    '''
    batch_size, n_points, _ = pixels.shape

    pixel_world = image_points_to_world(pixels, camera_mat, world_mat, scale_mat)
    camera_world = origin_to_world(n_points, camera_mat, world_mat, scale_mat)
    ray_vector = (pixel_world - camera_world)

    p_cube, d_cube, mask_cube = check_ray_intersection_with_unit_cube(
        camera_world, ray_vector, padding=padding, eps=eps
    )
    if not use_ray_length_as_depth:
        p_cam = transform_to_camera_space(
            p_cube.view(batch_size, -1, 3), camera_mat, world_mat, scale_mat
        ).view(batch_size, n_points, -1, 3)
        d_cube = p_cam[:, :, :, -1]
    return p_cube, d_cube, mask_cube


def arange_pixels(resolution=(128, 128), batch_size=1, image_range=(-1., 1.), subsample_to=None):
    ''' Arranges pixels for given resolution in range image_range.

    The function returns the unscaled pixel locations as integers and the
    scaled float values.

    Args:
        resolution (tuple): image resolution
        batch_size (int): batch size
        image_range (tuple): range of output points (default [-1, 1])
        subsample_to (int): if integer and > 0, the points are randomly
            subsampled to this value
    '''
    h, w = resolution
    n_points = resolution[0] * resolution[1]

    # Arrange pixel location in scale resolution
    pixel_locations = torch.meshgrid(torch.arange(0, w), torch.arange(0, h))
    pixel_locations = torch.stack([pixel_locations[0], pixel_locations[1]],
                                  dim=-1).long().view(1, -1, 2).repeat(batch_size, 1, 1)
    pixel_scaled = pixel_locations.clone().float()

    # Shift and scale points to match image_range
    scale = (image_range[1] - image_range[0])
    loc = scale / 2
    pixel_scaled[:, :, 0] = scale * pixel_scaled[:, :, 0] / (w - 1) - loc
    pixel_scaled[:, :, 1] = scale * pixel_scaled[:, :, 1] / (h - 1) - loc

    # Subsample points if subsample_to is not None and > 0
    if (subsample_to is not None and subsample_to > 0 and subsample_to < n_points):
        idx = np.random.choice(pixel_scaled.shape[1], size=(subsample_to, ), replace=False)
        pixel_scaled = pixel_scaled[:, idx]
        pixel_locations = pixel_locations[:, idx]

    return pixel_locations, pixel_scaled
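
# Usage sketch: enumerate all pixels of a 4 x 4 image and their [-1, 1] coordinates.
#     pix, pix_scaled = arange_pixels(resolution=(4, 4), batch_size=2)
#     # pix: (2, 16, 2) integer locations; pix_scaled: (2, 16, 2) floats in [-1, 1]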


def to_pytorch(tensor, return_type=False):
    ''' Converts input tensor to pytorch.

    Args:
        tensor (tensor): Numpy or Pytorch tensor
        return_type (bool): whether to return input type
    '''
    is_numpy = False
    if type(tensor) == np.ndarray:
        tensor = torch.from_numpy(tensor)
        is_numpy = True
    tensor = tensor.clone()
    if return_type:
        return tensor, is_numpy
    return tensor


def get_mask(tensor):
    ''' Returns a mask of valid (finite and non-NaN) values for the tensor.

    Args:
        tensor (tensor): Numpy or Pytorch tensor
    '''
    tensor, is_numpy = to_pytorch(tensor, True)
    mask = (abs(tensor) != np.inf) & ~torch.isnan(tensor)
    mask = mask.to(torch.bool)
    if is_numpy:
        mask = mask.numpy()

    return mask


def transform_mesh(mesh, transform):
    ''' Transforms a mesh with given transformation.

    Args:
        mesh (trimesh mesh): mesh
        transform (tensor): transformation matrix of size 4 x 4
    '''
    mesh = deepcopy(mesh)
    v = np.asarray(mesh.vertices).astype(np.float32)
    v_transformed = transform_pointcloud(v, transform)
    mesh.vertices = v_transformed
    return mesh


def transform_pointcloud(pointcloud, transform):
    ''' Transforms a point cloud with given transformation.

    Args:
        pointcloud (tensor): tensor of size N x 3
        transform (tensor): transformation of size 4 x 4
    '''

    assert (transform.shape == (4, 4) and pointcloud.shape[-1] == 3)

    pcl, is_numpy = to_pytorch(pointcloud, True)
    transform = to_pytorch(transform)

    # Transform point cloud to homogeneous coordinates
    pcl_hom = torch.cat([pcl, torch.ones(pcl.shape[0], 1)], dim=-1).transpose(1, 0)

    # Apply transformation to point cloud
    pcl_hom_transformed = transform @ pcl_hom

    # Transform back to 3D coordinates
    pcl_out = pcl_hom_transformed[:3].transpose(1, 0)
    if is_numpy:
        pcl_out = pcl_out.numpy()

    return pcl_out
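
# Usage sketch (illustrative values): translate a numpy point cloud by (1, 0, 0).
#     pts = np.random.rand(100, 3).astype(np.float32)
#     T = np.eye(4, dtype=np.float32); T[0, 3] = 1.
#     pts_t = transform_pointcloud(pts, T)        # numpy array of shape (100, 3)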


def transform_points_batch(p, transform):
    ''' Transform points tensor with given transform.

    Args:
        p (tensor): tensor of size B x N x 3
        transform (tensor): transformation of size B x 4 x 4
    '''
    device = p.device
    assert (transform.shape[1:] == (4, 4) and p.shape[-1] == 3 and p.shape[0] == transform.shape[0])

    # Transform points to homogeneous coordinates
    pcl_hom = torch.cat([p, torch.ones(p.shape[0], p.shape[1], 1).to(device)],
                        dim=-1).transpose(2, 1)

    # Apply transformation
    pcl_hom_transformed = transform @ pcl_hom

    # Transform back to 3D coordinates
    pcl_out = pcl_hom_transformed[:, :3].transpose(2, 1)
    return pcl_out


def get_tensor_values(
    tensor, p, grid_sample=True, mode='nearest', with_mask=False, squeeze_channel_dim=False
):
    ''' Returns values from tensor at given location p.

    Args:
        tensor (tensor): tensor of size B x C x H x W
        p (tensor): position values scaled between [-1, 1] and
            of size B x N x 2
        grid_sample (boolean): whether to use grid sampling
        mode (string): what mode to perform grid sampling in
        with_mask (bool): whether to return the mask for invalid values
        squeeze_channel_dim (bool): whether to squeeze the channel dimension
            (only applicable to 1D data)
    '''
    p = to_pytorch(p)
    tensor, is_numpy = to_pytorch(tensor, True)
    batch_size, _, h, w = tensor.shape

    if grid_sample:
        p = p.unsqueeze(1)
        values = torch.nn.functional.grid_sample(tensor, p, mode=mode)
        values = values.squeeze(2)
        values = values.permute(0, 2, 1)
    else:
        p[:, :, 0] = (p[:, :, 0] + 1) * (w) / 2
        p[:, :, 1] = (p[:, :, 1] + 1) * (h) / 2
        p = p.long()
        values = tensor[torch.arange(batch_size).unsqueeze(-1), :, p[:, :, 1], p[:, :, 0]]

    if with_mask:
        mask = get_mask(values)
        if squeeze_channel_dim:
            mask = mask.squeeze(-1)
        if is_numpy:
            mask = mask.numpy()

    if squeeze_channel_dim:
        values = values.squeeze(-1)

    if is_numpy:
        values = values.numpy()

    if with_mask:
        return values, mask
    return values
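
# Usage sketch (illustrative shapes): look up image values at normalized pixel locations.
#     img = torch.rand(2, 3, 64, 64)              # B x C x H x W
#     p = torch.rand(2, 100, 2) * 2 - 1           # positions in [-1, 1]
#     vals = get_tensor_values(img, p)            # -> (2, 100, 3)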


def transform_to_world(pixels, depth, camera_mat, world_mat, scale_mat, invert=True):
    ''' Transforms pixel positions p with given depth value d to world coordinates.

    Args:
        pixels (tensor): pixel tensor of size B x N x 2
        depth (tensor): depth tensor of size B x N x 1
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        scale_mat (tensor): scale matrix
        invert (bool): whether to invert matrices (default: true)
    '''
    assert (pixels.shape[-1] == 2)

    # Convert to pytorch
    pixels, is_numpy = to_pytorch(pixels, True)
    depth = to_pytorch(depth)
    camera_mat = to_pytorch(camera_mat)
    world_mat = to_pytorch(world_mat)
    scale_mat = to_pytorch(scale_mat)

    # Invert camera matrices
    if invert:
        camera_mat = torch.inverse(camera_mat)
        world_mat = torch.inverse(world_mat)
        scale_mat = torch.inverse(scale_mat)

    # Transform pixels to homogeneous coordinates
    pixels = pixels.permute(0, 2, 1)
    pixels = torch.cat([pixels, torch.ones_like(pixels)], dim=1)

    # Project pixels into camera space
    pixels[:, :3] = pixels[:, :3] * depth.permute(0, 2, 1)

    # Transform pixels to world space
    p_world = scale_mat @ world_mat @ camera_mat @ pixels

    # Transform p_world back to 3D coordinates
    p_world = p_world[:, :3].permute(0, 2, 1)

    if is_numpy:
        p_world = p_world.numpy()
    return p_world
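
# Usage sketch (identity matrices stand in for real camera/world/scale matrices):
#     B, N = 1, 100
#     pixels = torch.rand(B, N, 2) * 2 - 1
#     depth = torch.ones(B, N, 1)
#     eye = torch.eye(4).unsqueeze(0)
#     p_world = transform_to_world(pixels, depth, eye, eye, eye)   # -> (B, N, 3)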


def transform_to_camera_space(p_world, camera_mat, world_mat, scale_mat):
    ''' Transforms world points to camera space.

    Args:
        p_world (tensor): world points tensor of size B x N x 3
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        scale_mat (tensor): scale matrix
    '''
    batch_size, n_p, _ = p_world.shape
    device = p_world.device

    # Transform world points to homogeneous coordinates
    p_world = torch.cat([p_world, torch.ones(batch_size, n_p, 1).to(device)],
                        dim=-1).permute(0, 2, 1)

    # Apply matrices to transform p_world to camera space
    p_cam = camera_mat @ world_mat @ scale_mat @ p_world

    # Transform points back to 3D coordinates
    p_cam = p_cam[:, :3].permute(0, 2, 1)
    return p_cam


def origin_to_world(n_points, camera_mat, world_mat, scale_mat, invert=True):
    ''' Transforms origin (camera location) to world coordinates.

    Args:
        n_points (int): how often the transformed origin is repeated in the
            form (batch_size, n_points, 3)
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        scale_mat (tensor): scale matrix
        invert (bool): whether to invert the matrices (default: true)
    '''
    batch_size = camera_mat.shape[0]
    device = camera_mat.device

    # Create origin in homogeneous coordinates
    p = torch.zeros(batch_size, 4, n_points).to(device)
    p[:, -1] = 1.

    # Invert matrices
    if invert:
        camera_mat = torch.inverse(camera_mat)
        world_mat = torch.inverse(world_mat)
        scale_mat = torch.inverse(scale_mat)

    # Apply transformation
    p_world = scale_mat @ world_mat @ camera_mat @ p

    # Transform points back to 3D coordinates
    p_world = p_world[:, :3].permute(0, 2, 1)
    return p_world


def image_points_to_world(image_points, camera_mat, world_mat, scale_mat, invert=True):
    ''' Transforms points on image plane to world coordinates.

    In contrast to transform_to_world, no depth value is needed as points on
    the image plane have a fixed depth of 1.

    Args:
        image_points (tensor): image points tensor of size B x N x 2
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        scale_mat (tensor): scale matrix
        invert (bool): whether to invert matrices (default: true)
    '''
    batch_size, n_pts, dim = image_points.shape
    assert (dim == 2)
    device = image_points.device

    d_image = torch.ones(batch_size, n_pts, 1).to(device)
    return transform_to_world(
        image_points, d_image, camera_mat, world_mat, scale_mat, invert=invert
    )


def check_weights(params):
    ''' Checks weights for illegal values.

    Args:
        params (tensor): parameter tensor
    '''
    for k, v in params.items():
        if torch.isnan(v).any():
            logger_py.warning('NaN values detected in model weight %s.' % k)


def check_tensor(tensor, tensorname='', input_tensor=None):
    ''' Checks tensor for illegal values.

    Args:
        tensor (tensor): tensor
        tensorname (string): name of tensor
        input_tensor (tensor): previous input
    '''
    if torch.isnan(tensor).any():
        logger_py.warning('Tensor %s contains NaN values.' % tensorname)
        if input_tensor is not None:
            logger_py.warn(f'Input was: {input_tensor}')


def get_prob_from_logits(logits):
    ''' Returns probabilities for logits.

    Args:
        logits (tensor): logits
    '''
    odds = np.exp(logits)
    probs = odds / (1 + odds)
    return probs


def get_logits_from_prob(probs, eps=1e-4):
    ''' Returns logits for probabilities.

    Args:
        probs (tensor): probability tensor
        eps (float): epsilon value for numerical stability
    '''
    probs = np.clip(probs, a_min=eps, a_max=1 - eps)
    logits = np.log(probs / (1 - probs))
    return logits
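
# Usage sketch: the two helpers above are (numerically clipped) inverses of each other.
#     probs = np.array([0.1, 0.5, 0.9])
#     logits = get_logits_from_prob(probs)
#     assert np.allclose(get_prob_from_logits(logits), probs, atol=1e-3)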


def chamfer_distance(points1, points2, use_kdtree=True, give_id=False):
    ''' Returns the chamfer distance for the sets of points.

    Args:
        points1 (tensor): first point set
        points2 (tensor): second point set
        use_kdtree (bool): whether to use a kdtree
        give_id (bool): whether to return the IDs of nearest points
    '''
    if use_kdtree:
        return chamfer_distance_kdtree(points1, points2, give_id=give_id)
    else:
        return chamfer_distance_naive(points1, points2)
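
# Usage sketch (illustrative shapes; the naive variant requires both sets to
# have the same number of points):
#     pts1 = torch.rand(2, 500, 3)
#     pts2 = torch.rand(2, 500, 3)
#     cd = chamfer_distance(pts1, pts2)           # -> tensor of shape (2,)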


def chamfer_distance_naive(points1, points2):
    ''' Naive implementation of the Chamfer distance.

    Args:
        points1 (tensor): first point set of size B x T x 3
        points2 (tensor): second point set of size B x T x 3
    '''
    assert (points1.size() == points2.size())
    batch_size, T, _ = points1.size()

    points1 = points1.view(batch_size, T, 1, 3)
    points2 = points2.view(batch_size, 1, T, 3)

    distances = (points1 - points2).pow(2).sum(-1)

    chamfer1 = distances.min(dim=1)[0].mean(dim=1)
    chamfer2 = distances.min(dim=2)[0].mean(dim=1)

    chamfer = chamfer1 + chamfer2
    return chamfer


def chamfer_distance_kdtree(points1, points2, give_id=False):
    ''' KD-tree based implementation of the Chamfer distance.

    Args:
        points1 (tensor): first point set
        points2 (tensor): second point set
        give_id (bool): whether to return the IDs of the nearest points
    '''
    # Points have size batch_size x T x 3
    batch_size = points1.size(0)

    # First convert points to numpy
    points1_np = points1.detach().cpu().numpy()
    points2_np = points2.detach().cpu().numpy()

    # Get list of nearest neighbors indices
    idx_nn_12, _ = get_nearest_neighbors_indices_batch(points1_np, points2_np)
    idx_nn_12 = torch.LongTensor(idx_nn_12).to(points1.device)
    # Expands it as batch_size x T x 3
    idx_nn_12_expand = idx_nn_12.view(batch_size, -1, 1).expand_as(points1)

    # Get list of nearest neighbors indices
    idx_nn_21, _ = get_nearest_neighbors_indices_batch(points2_np, points1_np)
    idx_nn_21 = torch.LongTensor(idx_nn_21).to(points1.device)
    # Expands it as batch_size x T x 3
    idx_nn_21_expand = idx_nn_21.view(batch_size, -1, 1).expand_as(points2)

    # Compute nearest neighbors in points2 to points in points1
    # points_12[i, j, k] = points2[i, idx_nn_12_expand[i, j, k], k]
    points_12 = torch.gather(points2, dim=1, index=idx_nn_12_expand)

    # Compute nearest neighbors in points1 to points in points2
    # points_21[i, j, k] = points1[i, idx_nn_21_expand[i, j, k], k]
    points_21 = torch.gather(points1, dim=1, index=idx_nn_21_expand)

    # Compute chamfer distance
    chamfer1 = (points1 - points_12).pow(2).sum(2).mean(1)
    chamfer2 = (points2 - points_21).pow(2).sum(2).mean(1)

    # Take sum
    chamfer = chamfer1 + chamfer2

    # If required, also return nearest neighbors
    if give_id:
        return chamfer1, chamfer2, idx_nn_12, idx_nn_21

    return chamfer


def get_nearest_neighbors_indices_batch(points_src, points_tgt, k=1):
    ''' Returns the nearest neighbors for point sets batchwise.

    Args:
        points_src (numpy array): source points
        points_tgt (numpy array): target points
        k (int): number of nearest neighbors to return
    '''
    indices = []
    distances = []

    for (p1, p2) in zip(points_src, points_tgt):
        kdtree = KDTree(p2)
        dist, idx = kdtree.query(p1, k=k)
        indices.append(idx)
        distances.append(dist)

    return indices, distances


def normalize_imagenet(x):
    ''' Normalize input images according to ImageNet standards.

    Args:
        x (tensor): input images
    '''
    x = x.clone()
    x[:, 0] = (x[:, 0] - 0.485) / 0.229
    x[:, 1] = (x[:, 1] - 0.456) / 0.224
    x[:, 2] = (x[:, 2] - 0.406) / 0.225
    return x


def make_3d_grid(bb_min, bb_max, shape):
    ''' Makes a 3D grid.

    Args:
        bb_min (tuple): bounding box minimum
        bb_max (tuple): bounding box maximum
        shape (tuple): output shape
    '''
    size = shape[0] * shape[1] * shape[2]

    pxs = torch.linspace(bb_min[0], bb_max[0], shape[0])
    pys = torch.linspace(bb_min[1], bb_max[1], shape[1])
    pzs = torch.linspace(bb_min[2], bb_max[2], shape[2])

    pxs = pxs.view(-1, 1, 1).expand(*shape).contiguous().view(size)
    pys = pys.view(1, -1, 1).expand(*shape).contiguous().view(size)
    pzs = pzs.view(1, 1, -1).expand(*shape).contiguous().view(size)
    p = torch.stack([pxs, pys, pzs], dim=1)

    return p
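
# Usage sketch: a 32^3 grid of query points covering [-0.5, 0.5]^3.
#     p = make_3d_grid((-0.5,) * 3, (0.5,) * 3, (32,) * 3)   # -> (32768, 3)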


def get_occupancy_loss_points(
    pixels,
    camera_mat,
    world_mat,
    scale_mat,
    depth_image=None,
    use_cube_intersection=True,
    occupancy_random_normal=False,
    depth_range=[0, 2.4]
):
    ''' Returns 3D points for occupancy loss.

    Args:
        pixels (tensor): sampled pixels in range [-1, 1]
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        scale_mat (tensor): scale matrix
        depth_image (tensor): if not None, these depth values are used for
            initialization (e.g. depth or visual hull depth)
        use_cube_intersection (bool): whether to check unit cube intersection
        occupancy_random_normal (bool): whether to sample from a Normal
            distribution instead of a uniform one
        depth_range (list): depth range [near, far]; important when no cube
            intersection is used
    '''
    device = pixels.device
    batch_size, n_points, _ = pixels.shape

    if use_cube_intersection:
        _, d_cube_intersection, mask_cube = \
            intersect_camera_rays_with_unit_cube(
                pixels, camera_mat, world_mat, scale_mat, padding=0.,
                use_ray_length_as_depth=False)
        d_cube = d_cube_intersection[mask_cube]

    d_occupancy = torch.rand(batch_size, n_points).to(device) * depth_range[1]

    if use_cube_intersection:
        d_occupancy[mask_cube] = d_cube[:, 0] + \
            torch.rand(d_cube.shape[0]).to(
                device) * (d_cube[:, 1] - d_cube[:, 0])
    if occupancy_random_normal:
        d_occupancy = torch.randn(batch_size, n_points).to(device) \
            * (depth_range[1] / 8) + depth_range[1] / 2
        if use_cube_intersection:
            mean_cube = d_cube.sum(-1) / 2
            std_cube = (d_cube[:, 1] - d_cube[:, 0]) / 8
            d_occupancy[mask_cube] = mean_cube + \
                torch.randn(mean_cube.shape[0]).to(device) * std_cube

    if depth_image is not None:
        depth_gt, mask_gt_depth = get_tensor_values(
            depth_image, pixels, squeeze_channel_dim=True, with_mask=True
        )
        d_occupancy[mask_gt_depth] = depth_gt[mask_gt_depth]

    p_occupancy = transform_to_world(
        pixels, d_occupancy.unsqueeze(-1), camera_mat, world_mat, scale_mat
    )
    return p_occupancy


def get_freespace_loss_points(
    pixels, camera_mat, world_mat, scale_mat, use_cube_intersection=True, depth_range=[0, 2.4]
):
    ''' Returns 3D points for freespace loss.

    Args:
        pixels (tensor): sampled pixels in range [-1, 1]
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        scale_mat (tensor): scale matrix
        use_cube_intersection (bool): whether to check unit cube intersection
        depth_range (list): depth range [near, far]; important when no cube
            intersection is used
    '''
    device = pixels.device
    batch_size, n_points, _ = pixels.shape

    d_freespace = torch.rand(batch_size, n_points).to(device) * \
        depth_range[1]

    if use_cube_intersection:
        _, d_cube_intersection, mask_cube = \
            intersect_camera_rays_with_unit_cube(
                pixels, camera_mat, world_mat, scale_mat,
                use_ray_length_as_depth=False)
        d_cube = d_cube_intersection[mask_cube]
        d_freespace[mask_cube] = d_cube[:, 0] + \
            torch.rand(d_cube.shape[0]).to(
                device) * (d_cube[:, 1] - d_cube[:, 0])

    p_freespace = transform_to_world(
        pixels, d_freespace.unsqueeze(-1), camera_mat, world_mat, scale_mat
    )
    return p_freespace


def normalize_tensor(tensor, min_norm=1e-5, feat_dim=-1):
    ''' Normalizes the tensor to unit length along the feature dimension.

    Args:
        tensor (tensor): tensor
        min_norm (float): minimum norm for numerical stability
        feat_dim (int): feature dimension in tensor (default: -1)
    '''
    norm_tensor = torch.clamp(torch.norm(tensor, dim=feat_dim, keepdim=True), min=min_norm)
    normed_tensor = tensor / norm_tensor
    return normed_tensor
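
# Usage sketch: normalize predicted surface normals to unit length.
#     normals = torch.randn(4, 1000, 3)
#     unit_normals = normalize_tensor(normals)    # norms along the last dim are ~1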