FNet 1D FFT NKI Kernel for AWS Neuron

NeuronCore-optimized NKI kernels that compute 128-, 256-, and 512-point 1D FFTs entirely on NeuronCore hardware using Tensor Engine matrix multiplications. Here they replace the FNetBasicFourierTransform module in FNet models, whose 2D FFT is decomposed into two passes of 1D FFTs.

Why This Kernel Exists

torch.fft.fftn is not supported on Neuron hardware because complex dtypes are not natively supported. Without this kernel, the FFT falls back to the CPU, which requires:

  1. A device-to-host transfer (hundreds of ms)
  2. The FFT computed on the CPU
  3. A host-to-device transfer (hundreds of ms)

This kernel eliminates those transfers by computing the FFT entirely on-device.

Supported Models

All FNet models using FNetBasicFourierTransform:

Model               Hidden Size   Max Seq Len   Status
google/fnet-base    768           512           Tested: cos_sim=1.000001
google/fnet-large   1024          512           Supported (power-of-2 hidden size)
Custom configs      128-2048      128-512       Supported

Requirements

  • AWS Neuron SDK 2.29+ (NKI 0.3.0+)
  • PyTorch Native (torch-neuronx) for eager execution of the orchestration layer
  • transformers with KernelConfig support
  • kernels library (HuggingFace kernel loading)

Note on framework compatibility: The NKI kernels themselves (_fft1d_128, _fft1d_256, _fft1d_512) use only ISA-level operations (nisa.*) and are compatible with both PyTorch Native and standard torch-neuronx (the XLA trace path). However, the orchestration layer (nki_fft2d_real) uses eager Python loops and tensor operations, so it requires PyTorch Native (eager mode). To use these kernels in an XLA trace context, the orchestration layer would have to be rewritten to be trace-compatible.

Usage

from transformers import AutoModel, AutoTokenizer, KernelConfig

model_id = "google/fnet-base"

kernel_config = KernelConfig({
    "FNetBasicFourierTransform": "jburtoft/fnet-neuron-kernels:NeuronFNetFourierForward",
})

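model = AutoModel.from_pretrained(
    model_id,
    kernel_config=kernel_config,
    device_map="neuron",
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Pad to S=128 (a benchmarked length); .to("neuron") assumes device_map's device string.
inputs = tokenizer(
    "Hello, world!",
    padding="max_length",
    max_length=128,
    return_tensors="pt",
).to("neuron")
outputs = model(**inputs)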

Performance

Benchmarked on trn2.3xlarge (PyTorch Native Beta 3, SDK 2.30), where B is the batch size, S the sequence length, and D the hidden size:

Config                        NKI Kernel   CPU Fallback (D2H+FFT+H2D)   Speedup
B=1 S=128 D=512               6.2 ms       462 ms                       74.6x
B=1 S=128 D=768 (FNet-base)   8.0 ms       692 ms                       86.1x
B=1 S=256 D=768               12.0 ms      1390 ms                      116.3x
B=1 S=512 D=768 (full seq)    18.2 ms      2779 ms                      152.4x
B=4 S=128 D=768 (batched)     32.5 ms      2779 ms                      85.5x

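Accuracy: cosine similarity of 1.000001 to 1.000007 across all configurations. Cosine similarity is at most 1 in exact arithmetic, so the small excess above 1 is floating-point rounding.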
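The metric compares the flattened output with a reference. A minimal sketch, assuming a NumPy float64 CPU reference (the reference behind the figures above is not stated): cosine_similarity is a hypothetical helper, and a float32 cast of the reference stands in for the kernel output. Accumulating in float64 keeps rounding error in the metric itself small.

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of two tensors, flattened and accumulated in float64."""
    a64 = a.astype(np.float64).ravel()
    b64 = b.astype(np.float64).ravel()
    return float(a64 @ b64 / (np.linalg.norm(a64) * np.linalg.norm(b64)))


# Usage: compare a float32 result with a float64 CPU reference.
rng = np.random.default_rng(0)
x = rng.standard_normal((1, 128, 768))
reference = np.fft.fftn(x, axes=(1, 2)).real
candidate = reference.astype(np.float32)  # hypothetical stand-in for kernel output
print(cosine_similarity(candidate, reference))
```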

How It Works

FNet applies torch.fft.fftn(hidden_states, dim=(1,2)).real, the real part of a 2D FFT over the sequence and hidden dimensions. This is decomposed into two passes of 1D FFTs:

Input: hidden_states (B, S, D)

Pass 1: 1D FFT along hidden_dim (D) for each row
Pass 2: 1D FFT along seq_dim (S) for each column
Output: real part of 2D FFT result
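The decomposition can be checked on the CPU with NumPy. This is a reference sketch, not the kernel: np.fft.fft stands in for the on-device 1D FFT, and fft2d_real_two_pass is a hypothetical name. Because the 2D DFT is separable, the two passes reproduce fftn exactly; the intermediate result must stay complex, and only the final output is reduced to its real part.

```python
import numpy as np


def fft2d_real_two_pass(hidden_states: np.ndarray) -> np.ndarray:
    """Real part of the 2D FFT over (seq, hidden), computed as two 1D passes.

    hidden_states: real array of shape (B, S, D).
    """
    # Pass 1: 1D FFT along the hidden dimension (axis 2) of every row.
    # The result is complex and is kept complex for pass 2.
    pass1 = np.fft.fft(hidden_states, axis=2)
    # Pass 2: 1D FFT along the sequence dimension (axis 1) of every column.
    pass2 = np.fft.fft(pass1, axis=1)
    # Only the real part is returned, matching fftn(...).real in FNet.
    return pass2.real


# Usage: compare against the single 2D FFT call that FNet makes.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 128, 768))
reference = np.fft.fftn(x, axes=(1, 2)).real
out = fft2d_real_two_pass(x)
print(np.allclose(out, reference))  # True
```

Taking the real part after pass 1 instead would discard the imaginary components that pass 2 mixes back into the real output, which is why only the last step drops them.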

Each 1D FFT uses a flat radix-2 Cooley-Tukey algorithm:

  1. 128-pt base case: Complex DFT via nc_matmul on the 128x128 Tensor Engine (~90% utilization)
  2. 256-pt: 2 groups of 128 + 1 butterfly level with twiddle factors
  3. 512-pt: 4 groups of 128 + 2 butterfly levels
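The scheme above can be sketched in NumPy, carrying every complex value as a separate (real, imaginary) pair because complex dtypes are not available on-device. This is an illustration under assumptions, not the kernel source: np.matmul stands in for nc_matmul, the names BASE, dft_base, and fft_radix2 are hypothetical, and the recursion is written for readability, whereas the kernels are described as flat. The base case applies the 128-point DFT matrix W[k, m] = exp(-2πi·km/128) as two real matrices (cos and sin). Each butterfly level combines the transforms E and O of the even- and odd-indexed samples as X[k] = E[k] + w^k·O[k] and X[k + N/2] = E[k] - w^k·O[k], with w = exp(-2πi/N). In this decimation-in-time form the 128-point groups are strided subsequences of the input (every 2nd sample for 256 points, every 4th for 512); the kernels' actual grouping may differ.

```python
import numpy as np

BASE = 128  # base-case size: one 128x128 matmul (hypothetical constant)


def dft_base(xr, xi):
    """Base-case DFT over the last axis as real matmuls on (real, imag) pairs."""
    n = xr.shape[-1]
    k = np.arange(n)
    theta = 2.0 * np.pi * np.outer(k, k) / n
    # W = C - iS is symmetric, so x @ W applies the DFT to each row.
    c = np.cos(theta).astype(xr.dtype)
    s = np.sin(theta).astype(xr.dtype)
    # (xr + i*xi)(C - iS) = (xr C + xi S) + i (xi C - xr S)
    return xr @ c + xi @ s, xi @ c - xr @ s


def fft_radix2(xr, xi):
    """Radix-2 decimation-in-time FFT over the last axis without complex dtypes.

    The length must be BASE * 2**k; each level halves it until BASE is reached.
    """
    n = xr.shape[-1]
    if n == BASE:
        return dft_base(xr, xi)
    if n < BASE or n % 2:
        raise ValueError(f"length {n} is not {BASE} times a power of 2")
    # Transform the even- and odd-indexed samples separately.
    even_r, even_i = fft_radix2(xr[..., 0::2], xi[..., 0::2])
    odd_r, odd_i = fft_radix2(xr[..., 1::2], xi[..., 1::2])
    # Twiddle factors w^k = exp(-2j*pi*k/n) for k in [0, n/2).
    ang = 2.0 * np.pi * np.arange(n // 2) / n
    tw_r = np.cos(ang).astype(xr.dtype)
    tw_i = (-np.sin(ang)).astype(xr.dtype)
    # t[k] = w^k * O[k], as a complex multiply on (real, imag) pairs.
    t_r = odd_r * tw_r - odd_i * tw_i
    t_i = odd_r * tw_i + odd_i * tw_r
    # Butterfly: X[k] = E[k] + t[k] and X[k + n/2] = E[k] - t[k].
    out_r = np.concatenate([even_r + t_r, even_r - t_r], axis=-1)
    out_i = np.concatenate([even_i + t_i, even_i - t_i], axis=-1)
    return out_r, out_i


# Usage: a batch of four real 512-point signals in float32.
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 512)).astype(np.float32)
yr, yi = fft_radix2(x, np.zeros_like(x))
ref = np.fft.fft(x.astype(np.float64), axis=-1)
print("max abs error:", max(np.abs(yr - ref.real).max(), np.abs(yi - ref.imag).max()))
```

Splitting the 512-point case down to 128-point blocks keeps the bulk of the arithmetic in dense matmuls, leaving only elementwise twiddle multiplies and additions for the butterfly levels.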

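For non-power-of-2 dimensions (e.g., D=768), the input is padded to the next power of 2 (1024) and the result truncated. Zero-padding by itself samples the spectrum on a 1024-point frequency grid rather than the 768-point grid used by torch.fft.fftn, so padding and truncation alone do not reproduce the 768-point DFT.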

Architecture

NeuronFNetFourierLayout    - Weight structure (empty - no learnable params)
NeuronFNetFourierForward   - Forward pass (calls nki_fft2d_real)
nki_fft2d_real             - Orchestration: two 1D FFT passes (eager Python)
_fft1d_128/256/512         - @nki.jit kernels (radix-2 Cooley-Tukey, ISA-level)
_fft1d_matmul_isa          - 128-pt DFT via nc_matmul (Tensor Engine)

License

Apache 2.0
