FNet 1D FFT NKI Kernel for AWS Neuron
NeuronCore-optimized NKI kernels that compute 1D FFTs (128, 256, and 512 points) entirely on NeuronCore hardware using Tensor Engine matrix multiplications. They are used here to replace the FNetBasicFourierTransform module in FNet models, whose 2D FFT is decomposed into two passes of 1D FFTs.
Why This Kernel Exists
torch.fft.fftn is not supported on Neuron hardware because complex dtypes are not natively supported. Without this kernel, the FFT has to fall back to the CPU, which requires:
- Device-to-Host transfer (~hundreds of ms)
- CPU computation
- Host-to-Device transfer (~hundreds of ms)
This kernel eliminates those transfers by computing the FFT entirely on-device.
Supported Models
All FNet models using FNetBasicFourierTransform:
| Model | Hidden Size | Max Seq Len | Tested |
|---|---|---|---|
| google/fnet-base | 768 | 512 | Tested: cos_sim=1.000001 |
| google/fnet-large | 1024 | 512 | Supported (power-of-2 hidden size) |
| Custom configs | 128-2048 | 128-512 | Supported |
Requirements
- AWS Neuron SDK 2.29+ (NKI 0.3.0+)
- PyTorch Native (torch-neuronx) for eager execution of the orchestration layer
- transformers with KernelConfig support
- kernels library (HuggingFace kernel loading)
Note on framework compatibility: The NKI kernels themselves (_fft1d_128, _fft1d_256, _fft1d_512) use only ISA-level operations (nisa.*) and are compatible with both PyTorch Native and standard torch-neuronx (XLA trace path). However, the orchestration layer (nki_fft2d_real) uses eager Python loops and tensor operations, so it requires PyTorch Native (eager mode). To use these kernels in an XLA trace context, the orchestration layer would need to be rewritten to be trace-compatible.
Usage
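```python
from transformers import AutoModel, AutoTokenizer, KernelConfig

model_id = "google/fnet-base"

# Route the FNetBasicFourierTransform module to the Neuron kernel implementation
kernel_config = KernelConfig({
    "FNetBasicFourierTransform": "jburtoft/fnet-neuron-kernels:NeuronFNetFourierForward",
})

# Load the model onto the Neuron device with the kernel mapping applied
model = AutoModel.from_pretrained(
    model_id,
    kernel_config=kernel_config,
    device_map="neuron",
)
```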
Performance
Benchmarked on trn2.3xlarge (PyTorch Native Beta 3, SDK 2.30):
| Config | NKI Kernel | CPU Fallback (D2H+FFT+H2D) | Speedup |
|---|---|---|---|
| B=1 S=128 D=512 | 6.2 ms | 462 ms | 74.6x |
| B=1 S=128 D=768 (FNet-base) | 8.0 ms | 692 ms | 86.1x |
| B=1 S=256 D=768 | 12.0 ms | 1390 ms | 116.3x |
| B=1 S=512 D=768 (full seq) | 18.2 ms | 2779 ms | 152.4x |
| B=4 S=128 D=768 (batched) | 32.5 ms | 2779 ms | 85.5x |
Accuracy: cosine similarity = 1.000001-1.000007 across all configurations.
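The accuracy figure compares the kernel output against a reference FFT. A minimal sketch of that kind of check, assuming the reference is computed on the host with NumPy's np.fft.fftn (standing in for torch.fft.fftn) and using hypothetical helper names cosine_similarity and fnet_reference. The device output is simulated here with small added noise:

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of two arrays, flattened to 1D and computed in float64."""
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


def fnet_reference(hidden_states: np.ndarray) -> np.ndarray:
    """Hypothetical host reference: real part of the 2D FFT over (S, D)."""
    return np.fft.fftn(hidden_states, axes=(1, 2)).real


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 128, 768))  # (B, S, D), FNet-base hidden size
ref = fnet_reference(x)

# Simulated device output: the reference plus small noise.
# In practice this would be the kernel result copied back to the host.
kernel_out = ref + 1e-4 * rng.standard_normal(ref.shape)

print(f"cos_sim = {cosine_similarity(kernel_out, ref):.6f}")
```

In exact arithmetic cosine similarity is bounded by 1 (Cauchy-Schwarz), so reported values slightly above 1 reflect floating-point rounding in the metric rather than a better-than-perfect match.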
How It Works
FNet applies torch.fft.fftn(hidden_states, dim=(1,2)).real, the real part of a 2D FFT over the sequence and hidden dimensions. It is decomposed into two passes of 1D FFTs:
```text
Input:  hidden_states (B, S, D)
Pass 1: 1D FFT along hidden_dim (D) for each row
Pass 2: 1D FFT along seq_dim (S) for each column
Output: real part of 2D FFT result
```
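The decomposition is exact because the 2D DFT is separable. A short NumPy sketch that checks it, with np.fft.fft standing in for the on-device 1D kernels (the function name fft2d_real_two_pass is hypothetical):

```python
import numpy as np


def fft2d_real_two_pass(hidden_states: np.ndarray) -> np.ndarray:
    """Real part of the 2D FFT over (S, D), computed as two passes of 1D FFTs."""
    # Pass 1: 1D FFT along the hidden dimension D (last axis) for every (b, s) row.
    pass1 = np.fft.fft(hidden_states, axis=2)
    # Pass 2: 1D FFT along the sequence dimension S for every (b, d) column.
    # Its input is the full complex result of pass 1, not just the real part.
    pass2 = np.fft.fft(pass1, axis=1)
    # Keep only the real part, matching torch.fft.fftn(...).real.
    return pass2.real


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 128, 256))  # (B, S, D)
two_pass = fft2d_real_two_pass(x)
reference = np.fft.fftn(x, axes=(1, 2)).real
print(np.allclose(two_pass, reference))  # True
```

The .real can only be taken after the second pass: discarding the imaginary part of pass 1 drops the cross terms that pass 2 folds back into the real output, so the intermediate has to carry both real and imaginary components.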
Each 1D FFT uses a flat radix-2 Cooley-Tukey algorithm:
- 128-pt base case: complex DFT via nc_matmul on the 128x128 Tensor Engine (~90% utilization)
- 256-pt: 2 groups of 128 + 1 butterfly level with twiddle factors
- 512-pt: 4 groups of 128 + 2 butterfly levels
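The size breakdown in the list above matches a decimation-in-time radix-2 recursion that bottoms out at a 128-point DFT. A minimal NumPy sketch, assuming the base-case DFT is expressed as real-valued matrix products on separate real and imaginary parts (since complex dtypes are not available on the device). It is written recursively for readability and does not reproduce the kernel's flat, ISA-level layout; all function names are hypothetical:

```python
import numpy as np

BASE = 128  # base-case length handled by a single DFT matrix product


def dft_matrices(n: int):
    """Real cosine and sine matrices of the n-point DFT, so that W = C - i*S."""
    k = np.arange(n)
    angle = 2.0 * np.pi * np.outer(k, k) / n
    return np.cos(angle), np.sin(angle)


def dft_real_matmul(xr: np.ndarray, xi: np.ndarray):
    """Complex DFT using only real matrix products.

    With W = C - i*S and x = xr + i*xi:
      Re(W x) = C @ xr + S @ xi
      Im(W x) = C @ xi - S @ xr
    """
    c, s = dft_matrices(xr.shape[0])
    return c @ xr + s @ xi, c @ xi - s @ xr


def fft_radix2(xr: np.ndarray, xi: np.ndarray):
    """Radix-2 decimation-in-time FFT for lengths 128 * 2**k."""
    n = xr.shape[0]
    if n == BASE:
        return dft_real_matmul(xr, xi)
    if n < BASE or n % 2:
        raise ValueError(f"length must be 128 * 2**k, got {n}")
    # Split into even- and odd-indexed halves and transform each half.
    er, ei = fft_radix2(xr[0::2], xi[0::2])
    or_, oi = fft_radix2(xr[1::2], xi[1::2])
    # Twiddle factors w_k = exp(-2*pi*i*k/n) for k in [0, n/2).
    k = np.arange(n // 2)
    wr, wi = np.cos(2 * np.pi * k / n), -np.sin(2 * np.pi * k / n)
    # t = w * odd, written with real arithmetic only.
    tr = wr * or_ - wi * oi
    ti = wr * oi + wi * or_
    # Butterfly: X[k] = E[k] + t[k], X[k + n/2] = E[k] - t[k].
    return np.concatenate([er + tr, er - tr]), np.concatenate([ei + ti, ei - ti])


rng = np.random.default_rng(0)
for n in (128, 256, 512):
    x = rng.standard_normal(n) + 1j * rng.standard_normal(n)
    yr, yi = fft_radix2(x.real, x.imag)
    print(n, np.allclose(yr + 1j * yi, np.fft.fft(x)))
```

A 256-point input recurses once (two 128-point groups, one butterfly level) and a 512-point input recurses twice (four 128-point groups, two butterfly levels), which is the structure listed above.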
For non-power-of-2 dimensions (e.g., D=768), the input is padded to the next power of 2 (1024) and the result truncated.
Architecture
```text
NeuronFNetFourierLayout   - Weight structure (empty - no learnable params)
NeuronFNetFourierForward  - Forward pass (calls nki_fft2d_real)
nki_fft2d_real            - Orchestration: two 1D FFT passes (eager Python)
_fft1d_128/256/512        - @nki.jit kernels (radix-2 Cooley-Tukey, ISA-level)
_fft1d_matmul_isa         - 128-pt DFT via nc_matmul (Tensor Engine)
```
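The orchestration layer's job is to batch rows, pick the size-specialized 1D kernel, and transpose between passes. A host-side sketch of that dispatch pattern, assuming each kernel transforms a 2D batch of rows along its last axis; the kernels are replaced by hypothetical NumPy stubs, and FFT1D_KERNELS, fft1d_rows, and fft2d_real_sketch are illustrative names, not the repository's code:

```python
import numpy as np


def _make_kernel_stub(n: int):
    """Hypothetical stand-in for one size-specialized 1D FFT kernel."""
    def fft1d(rows: np.ndarray) -> np.ndarray:
        if rows.shape[-1] != n:
            raise ValueError(f"kernel expects {n}-point rows, got {rows.shape[-1]}")
        return np.fft.fft(rows, axis=-1)
    return fft1d


# One stub per supported length, mirroring _fft1d_128/256/512.
FFT1D_KERNELS = {n: _make_kernel_stub(n) for n in (128, 256, 512)}


def fft1d_rows(rows: np.ndarray) -> np.ndarray:
    """Dispatch a (batch, n) block of rows to the kernel for length n."""
    n = rows.shape[-1]
    kernel = FFT1D_KERNELS.get(n)
    if kernel is None:
        raise ValueError(f"no 1D FFT kernel for length {n}")
    return kernel(rows)


def fft2d_real_sketch(hidden_states: np.ndarray) -> np.ndarray:
    """Two-pass orchestration returning Re(2D FFT) over (S, D)."""
    b, s, d = hidden_states.shape
    # Pass 1: flatten (B, S) so each length-D row is one FFT input.
    pass1 = fft1d_rows(hidden_states.reshape(b * s, d)).reshape(b, s, d)
    # Pass 2: transpose so S is the last axis, then FFT each length-S column.
    cols = pass1.transpose(0, 2, 1).reshape(b * d, s)
    pass2 = fft1d_rows(cols).reshape(b, d, s).transpose(0, 2, 1)
    return pass2.real


x = np.random.default_rng(0).standard_normal((2, 128, 256))
print(np.allclose(fft2d_real_sketch(x), np.fft.fftn(x, axes=(1, 2)).real))
```

This sketch leaves out the handling of non-power-of-2 dimensions and simply rejects lengths other than 128, 256, and 512.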
License
Apache 2.0