Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
# | |
# @File: __init__.py | |
# @Author: Haozhe Xie | |
# @Date: 2023-03-24 20:24:38 | |
# @Last Modified by: Haozhe Xie | |
# @Last Modified at: 2023-06-16 09:55:58 | |
# @Email: root@haozhexie.com | |
import torch | |
import extrude_tensor_ext | |
class TensorExtruder(torch.nn.Module):
    """Module front-end for ``ExtrudeTensorFunction``.

    Checks that every value in ``height_field`` is below ``max_height`` and
    then hands ``seg_map``, ``height_field`` and ``max_height`` to the
    autograd function, which calls the compiled ``extrude_tensor_ext``
    extension.

    Args:
        max_height: Exclusive upper bound for the values of ``height_field``.
            It is also forwarded to the extension. Presumably it sets the depth
            of the produced volume (TODO confirm against the extension).
    """

    def __init__(self, max_height=256):
        super().__init__()
        self.max_height = max_height

    def forward(self, seg_map, height_field):
        """Extrude ``seg_map`` using ``height_field``.

        Raises:
            AssertionError: if the largest value in ``height_field`` is not
                strictly below ``max_height``. Being an ``assert``, this check
                is skipped when Python runs with ``-O``.
        """
        limit = self.max_height
        # Guard the extension against heights it cannot represent; the message
        # reports the offending maximum.
        assert torch.max(height_field) < limit, "Max Value %d" % torch.max(
            height_field
        )
        return ExtrudeTensorFunction.apply(seg_map, height_field, limit)
class ExtrudeTensorFunction(torch.autograd.Function):
    """Autograd bridge to the compiled ``extrude_tensor_ext`` extension.

    ``forward`` delegates to ``extrude_tensor_ext.forward``. ``backward``
    collapses the incoming volume gradient along its last (Z) axis.
    """

    @staticmethod
    def forward(ctx, seg_map, height_field, max_height):
        """Run the extension's forward pass.

        Args:
            ctx: Autograd context (unused here).
            seg_map: Assumed shape (B, C, H, W).
            height_field: Assumed shape (B, C, H, W).
            max_height: Value forwarded verbatim to the extension
                (``TensorExtruder`` passes its integer ``max_height``).

        Returns:
            Whatever ``extrude_tensor_ext.forward`` returns. Per ``backward``'s
            gradient, presumably a (B, C, H, W, D) volume (TODO confirm).
        """
        return extrude_tensor_ext.forward(seg_map, height_field, max_height)

    @staticmethod
    def backward(ctx, grad_volume):
        """Map the volume gradient back onto the three forward inputs.

        Args:
            ctx: Autograd context (unused here).
            grad_volume: Gradient of the output, expected shape (B, C, H, W, D).

        Returns:
            A 3-tuple ``(grad_seg_map, grad_height_field, None)``, with one
            entry per ``forward`` input, as autograd requires.
        """
        # Combine the gradients along the Z-axis.
        grad_seg_map = torch.sum(grad_volume, dim=4)
        # NOTE(review): the seg-map gradient is reused as a surrogate gradient
        # for the height field; confirm this approximation is intended.
        grad_height_field = grad_seg_map
        # max_height is a configuration value, not a differentiable tensor, so
        # it gets no gradient. Omitting this slot makes autograd fail with
        # "incorrect number of gradients (expected 3, got 2)".
        return grad_seg_map, grad_height_field, None