Spaces:
Sleeping
Sleeping
File size: 1,242 Bytes
79df973 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
# -*- coding: utf-8 -*-
#
# @File: __init__.py
# @Author: Haozhe Xie
# @Date: 2023-03-24 20:24:38
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2023-06-16 09:55:58
# @Email: root@haozhexie.com
import torch
import extrude_tensor_ext
class TensorExtruder(torch.nn.Module):
    """Module front-end for ``ExtrudeTensorFunction``.

    Validates the height field against ``max_height`` and then delegates to
    the autograd function, which calls the compiled ``extrude_tensor_ext``
    extension.
    """

    def __init__(self, max_height=256):
        """Create the extruder.

        Args:
            max_height: Exclusive upper bound on the values of
                ``height_field``. It is forwarded unchanged to the extension.
        """
        super().__init__()
        self.max_height = max_height

    def forward(self, seg_map, height_field):
        """Extrude ``seg_map`` using ``height_field``.

        Args:
            seg_map: Segmentation tensor. Its shape is presumably
                (B, C, H, W), per the comment in ``ExtrudeTensorFunction``
                (TODO confirm).
            height_field: Height tensor matching ``seg_map``. Every value
                must be strictly less than ``self.max_height``.

        Returns:
            Whatever ``ExtrudeTensorFunction.apply`` returns.

        Raises:
            ValueError: If any height is >= ``self.max_height`` or is NaN.
                An explicit check is used instead of ``assert`` so that the
                validation still runs under ``python -O``.
        """
        # Reduce once and pull the result out as a Python scalar. The original
        # ran the reduction twice and formatted the tensor with %d, which
        # fails for NaN.
        peak = torch.max(height_field).item()
        # Written as ``not (peak < limit)`` so that NaN is rejected too.
        if not peak < self.max_height:
            raise ValueError(
                "Max Value %s (must be < %s)" % (peak, self.max_height)
            )
        return ExtrudeTensorFunction.apply(seg_map, height_field, self.max_height)
class ExtrudeTensorFunction(torch.autograd.Function):
    """Autograd wrapper around the compiled ``extrude_tensor_ext`` extension.

    The forward pass is delegated entirely to the extension. The backward
    pass sums the incoming volume gradient over its last axis.
    """

    @staticmethod
    def forward(ctx, seg_map, height_field, max_height):
        """Run the extension's forward kernel.

        Args:
            ctx: Autograd context (unused; nothing is saved for backward).
            seg_map: Tensor with shape (B, C, H, W).
            height_field: Tensor with shape (B, C, H, W).
            max_height: Value forwarded to the extension. In this file it is
                ``TensorExtruder.max_height``, a plain Python number.

        Returns:
            The tensor produced by ``extrude_tensor_ext.forward``.
        """
        return extrude_tensor_ext.forward(seg_map, height_field, max_height)

    @staticmethod
    def backward(ctx, grad_volume):
        """Map the volume gradient back to the forward inputs.

        Args:
            ctx: Autograd context (unused).
            grad_volume: Gradient w.r.t. the forward output, with shape
                (B, C, H, W, D).

        Returns:
            A 3-tuple with one entry per ``forward`` input:
            ``(grad_seg_map, grad_height_field, None)``.
        """
        # Collapse the extruded depth axis (dim 4) so the gradient matches
        # the 4-D inputs.
        grad_seg_map = torch.sum(grad_volume, dim=4)
        # NOTE(review): height_field reuses the seg_map gradient unchanged.
        # This is kept from the original; confirm it is the intended
        # approximation.
        grad_height_field = grad_seg_map
        # autograd requires exactly one gradient per forward() input.
        # max_height is not a tensor, so its gradient is None. Returning only
        # two values made every backward pass raise.
        return grad_seg_map, grad_height_field, None
|