// Listing metadata (file-viewer residue): size 485 Bytes, id 2595c46
#pragma once
#include <torch/all.h>
namespace megablocks {

/// @brief Forward pass of the replicate op: replicates values from `x`
///        according to bin sizes.
///
/// @param x    Tensor holding the values to replicate.
/// @param bins Tensor describing the bins.
///             NOTE(review): its layout, dtype and device are not visible
///             from this header -- confirm against the definition.
/// @param out  Presumably the destination tensor, since the function returns
///             void -- TODO confirm against the definition.
///
/// All tensors are taken by value; keep this signature in sync with the
/// definition, which lives outside this header.
void replicate_forward(torch::Tensor x, torch::Tensor bins, torch::Tensor out);

/// @brief Backward pass of the replicate op: reduces gradients back to bins
///        using a segmented reduction.
///
/// @param grad Tensor holding the incoming gradients.
/// @param bins Tensor describing the bins.
///             NOTE(review): presumably the same bins passed to the forward
///             pass -- confirm against the caller.
/// @param out  Presumably the destination tensor for the reduced gradients,
///             since the function returns void -- TODO confirm.
///
/// All tensors are taken by value; keep this signature in sync with the
/// definition, which lives outside this header.
void replicate_backward(torch::Tensor grad, torch::Tensor bins, torch::Tensor out);

}  // namespace megablocks