kernel
File size: 485 Bytes
2595c46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#pragma once

#include <torch/all.h>

namespace megablocks {

// Forward pass of the replicate op: values from `x` are replicated
// according to bin sizes.
//
//   x    - tensor supplying the values to replicate.
//   bins - tensor describing the bins that drive the replication.
//          NOTE(review): whether this holds per-bin sizes or cumulative
//          boundaries is not visible from this header; confirm against
//          the definition.
//   out  - NOTE(review): presumably the destination written by the op,
//          since the function returns void; confirm against the definition.
//
// Shapes, dtypes and device placement are not stated here; check the
// implementation before relying on any of them.
void replicate_forward(
    torch::Tensor x, torch::Tensor bins, torch::Tensor out);

// Backward pass of the replicate op: gradients are reduced back to bins
// using a segmented reduction.
//
//   grad - gradient tensor to be reduced.
//   bins - tensor describing the bins used for the reduction.
//          NOTE(review): presumably the same bins passed to
//          replicate_forward; verify against the caller.
//   out  - NOTE(review): presumably receives the reduced gradients, since
//          the function returns void; confirm against the definition.
void replicate_backward(
    torch::Tensor grad, torch::Tensor bins, torch::Tensor out);

}  // namespace megablocks