Writing kernels
This guide explains how to write kernels that go beyond a stateless forward replacement. It covers two capabilities the extended KernelConfig API supports:
- Parameter transformation: the kernel expects weights in a different layout than the original model (for example, renamed or merged parameters).
- Module fusion: the kernel replaces multiple adjacent modules with a single fused implementation.
For basic kernels (stateless forward replacements with no parameter changes), see the kernels library documentation.
Two-class pattern
Any kernel that carries its own parameters follows a two-class pattern.
- KernelName: contains only the forward pass. The kernels library uses this class to kernelize the model because it does not allow stateful kernel classes.
- KernelNameLayout: an nn.Module that holds the parameters and monkey-patches the original module before the checkpoint is loaded. At runtime, kernelize replaces its forward with the forward from KernelName. You do not need to define forward; Transformers injects one automatically with the same signature as KernelName.forward.
The naming convention is strict. The layout class must be named {KernelName}Layout and defined in the same module as KernelName.
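The naming rule can be pictured as a plain attribute lookup followed by a method swap. The following is a minimal sketch of that idea, not the Transformers implementation: find_layout_class and inject_forward are hypothetical helpers, kernel_module is a stand-in for the module in the kernel repo, and plain Python classes replace nn.Module so the example runs without PyTorch.

```python
import types


class Scale:
    """Stateless kernel class: defines only forward."""

    def forward(self, x):
        return [self.scale * v for v in x]


class ScaleLayout:
    """Layout class: owns the state and defines no forward of its own."""

    def __init__(self, scale):
        self.scale = scale


# Stand-in for the Python module that the kernel repo would expose.
kernel_module = types.SimpleNamespace(Scale=Scale, ScaleLayout=ScaleLayout)


def find_layout_class(module, kernel_name):
    # Hypothetical helper: look for "{KernelName}Layout" next to the kernel
    # class. Returns None when the kernel has no layout (a stateless kernel).
    return getattr(module, f"{kernel_name}Layout", None)


def inject_forward(layout_cls, kernel_cls):
    # Hypothetical helper: reuse the kernel's forward on the layout class, so
    # "self" inside forward is the layout instance that holds the state.
    layout_cls.forward = kernel_cls.forward
    return layout_cls


layout_cls = inject_forward(find_layout_class(kernel_module, "Scale"), Scale)
print(layout_cls(2.0).forward([1.0, 2.0]))
```

Because the lookup relies only on the class name, a layout class with any other name, or one defined in a different module, would not be found in this sketch.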
Parameter transformation
Use this pattern when the kernel expects weights under different names or in a different shape than the original model checkpoint.
The KernelNameLayout class has the same __init__ signature as the module it replaces and declares a conversion_mapping class attribute that tells Transformers how to remap checkpoint keys to the new parameter names (see Dynamic weight loading for more details).
import torch
import torch.nn as nn


class CustomRMSNormLayout(nn.Module):
    conversion_mapping = [...]  # rules that remap checkpoint keys to the new parameter names

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps


class CustomRMSNorm(nn.Module):
    # No __init__ and no parameters: at runtime, self is a CustomRMSNormLayout instance.
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        # Compute the normalization in float32, then cast back to the input dtype.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.scale * hidden_states.to(input_dtype)


class layers:
    CustomRMSNorm = CustomRMSNorm

The layers class is required by the kernels library to expose the kernel entry point.
Load this kernel by passing the repo and class name to KernelConfig. The key is the original module class name from the model. The value points to the KernelName class (not the Layout) in the repo.
from transformers import AutoModelForCausalLM, KernelConfig

kernel_config = KernelConfig({"RMSNorm": "owner/my-kernel:CustomRMSNorm"})

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    use_kernels=True,
    kernel_config=kernel_config,
    device_map="cuda",
)

When the model loads, Transformers:
- Loads CustomRMSNorm from the repo and looks for CustomRMSNormLayout in the same module.
- Monkey-patches every RMSNorm in the model with CustomRMSNormLayout.
- Remaps checkpoint weights using conversion_mapping so they load into the new parameter names.
- Calls kernelize, which replaces CustomRMSNormLayout.forward with CustomRMSNorm.forward.
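The remapping step can be pictured as a key rename over the checkpoint's state dict. The sketch below only illustrates that idea. It does not use the conversion_mapping format (see Dynamic weight loading for that), rename_keys is a hypothetical helper, the checkpoint keys are made up, and it assumes the original norm stores its parameter as weight while the layout calls it scale.

```python
import re


def rename_keys(state_dict, rules):
    # Hypothetical helper: apply (pattern, replacement) regex rules to every
    # key. A key that matches no rule is kept unchanged.
    renamed = {}
    for key, value in state_dict.items():
        for pattern, replacement in rules:
            key = re.sub(pattern, replacement, key)
        renamed[key] = value
    return renamed


# Made-up checkpoint keys; strings stand in for tensors.
checkpoint = {
    "model.layers.0.input_layernorm.weight": "norm-tensor",
    "model.layers.0.mlp.down_proj.weight": "linear-tensor",
}

# Only keys ending in "layernorm.weight" become "...layernorm.scale".
rules = [(r"layernorm\.weight$", "layernorm.scale")]

remapped = rename_keys(checkpoint, rules)
print(sorted(remapped))
```

The linear layer's key is left alone because the rule is anchored to the norm's name, which is the property a real mapping also needs: it must touch only the parameters that the layout class renames.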
Module fusion
Use this pattern when a kernel replaces multiple adjacent modules with a single fused implementation. Because the fused module combines parameters from several original modules, KernelNameLayout.__init__ receives the instantiated child modules rather than their constructor arguments.
import torch
import torch.nn as nn


class RMSNormMLPLayout(nn.Module):
    conversion_mapping = [...]  # rules that remap checkpoint keys to the fused parameter names

    def __init__(self, norm, mlp):
        super().__init__()
        # Norm state, copied from the instantiated RMSNorm child.
        self.variance_epsilon = norm.variance_epsilon
        self.scale = nn.Parameter(torch.empty_like(norm.weight))
        # gate_proj and up_proj are merged into a single projection.
        self.gate_up_proj = nn.Linear(
            mlp.gate_proj.in_features,
            mlp.gate_proj.out_features + mlp.up_proj.out_features,
            bias=mlp.gate_proj.bias is not None,
            device=mlp.gate_proj.weight.device,
            dtype=mlp.gate_proj.weight.dtype,
        )
        self.down_proj = nn.Linear(
            mlp.down_proj.in_features,
            mlp.down_proj.out_features,
            bias=mlp.down_proj.bias is not None,
            device=mlp.down_proj.weight.device,
            dtype=mlp.down_proj.weight.dtype,
        )
        self.act_fn = mlp.act_fn


class RMSNormMLP(nn.Module):
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # RMSNorm, computed in float32 and cast back to the input dtype.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = self.scale * hidden_states.to(input_dtype)
        # Gated MLP: split the fused projection into its gate and up halves.
        gate, up = self.gate_up_proj(hidden_states).chunk(2, dim=-1)
        return self.down_proj(self.act_fn(gate) * up)


class layers:
    RMSNormMLP = RMSNormMLP

To fuse modules, pass a tuple of (class_name, path_pattern) pairs as the key in KernelConfig instead of a plain string. All patterns must share the same parent module (Transformers fuses the children in that parent). The * wildcard matches any single path segment.
from transformers import AutoModelForCausalLM, KernelConfig

kernel_config = KernelConfig(
    {
        (
            ("RMSNorm", "model.layers.*.post_attention_layernorm"),
            ("MLP", "model.layers.*.mlp"),
        ): "owner/my-kernel:RMSNormMLP",
    }
)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    use_kernels=True,
    kernel_config=kernel_config,
    device_map="cuda",
)

When the model loads, Transformers:
- Loads RMSNormMLP from the repo and finds RMSNormMLPLayout in the same module.
- Matches every decoder layer at model.layers.* and builds a fused parent class whose __init__ calls RMSNormMLPLayout(post_attention_layernorm, mlp).
- Replaces the remaining child (mlp) with nn.Identity() to preserve the parent module's interface.
- Remaps checkpoint weights using conversion_mapping.
- Calls kernelize, which replaces RMSNormMLPLayout.forward with RMSNormMLP.forward.
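The path patterns in the fusion key can be read as dotted paths in which * stands for exactly one segment. The sketch below shows one way to express that matching rule and to group the matched children by parent. It assumes nothing about how Transformers implements the step: pattern_to_regex and group_by_parent are hypothetical helpers, and the module paths are examples.

```python
import re
from collections import defaultdict


def pattern_to_regex(pattern):
    # "*" matches one path segment, so it must not cross a "." separator.
    parts = [r"[^.]+" if part == "*" else re.escape(part) for part in pattern.split(".")]
    return re.compile(r"\.".join(parts))


def group_by_parent(paths, patterns):
    # Collect the matched children of each parent, ordered as in `patterns`.
    regexes = [pattern_to_regex(p) for p in patterns]
    groups = defaultdict(dict)
    for path in paths:
        for index, regex in enumerate(regexes):
            if regex.fullmatch(path):
                parent, _, child = path.rpartition(".")
                groups[parent][index] = child
    # Keep only parents in which every pattern matched a child.
    return {
        parent: [children[i] for i in range(len(patterns))]
        for parent, children in groups.items()
        if len(children) == len(patterns)
    }


paths = [
    "model.layers.0.post_attention_layernorm",
    "model.layers.0.mlp",
    "model.layers.0.mlp.down_proj",
    "model.layers.1.post_attention_layernorm",
    "model.layers.1.mlp",
    "model.norm",
]
patterns = ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"]

fused = group_by_parent(paths, patterns)
print(fused)
```

In this sketch, model.layers.0.mlp.down_proj is not matched, because the pattern ends at mlp and the wildcard covers only the layer index.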
The order of pairs in the fusion tuple determines the argument order passed to KernelNameLayout.__init__.
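The forward in RMSNormMLP splits the output of gate_up_proj into two halves, gate first and up second. For the fused layer to reproduce the original projections, the fused weight therefore has to stack the gate_proj rows on top of the up_proj rows. The following worked example checks that arithmetic with plain Python lists in place of tensors. The linear helper and the weight values are illustrative assumptions, and the sketch says nothing about how conversion_mapping expresses the merge.

```python
def linear(weight, x):
    # y = W @ x for a weight stored as rows (out_features x in_features).
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


# Illustrative 2x2 weights for the two original projections.
gate_w = [[1.0, 0.0], [0.0, 2.0]]
up_w = [[3.0, 1.0], [1.0, 1.0]]

# Fused weight: gate rows first, then up rows (concatenation along out_features).
fused_w = gate_w + up_w

x = [1.0, 2.0]
fused_out = linear(fused_w, x)

# Same split as chunk(2, dim=-1): first half is gate, second half is up.
half = len(fused_out) // 2
gate, up = fused_out[:half], fused_out[half:]

print(gate == linear(gate_w, x), up == linear(up_w, x))
```

Stacking the rows in the opposite order would swap the two halves, so the merge order has to agree with the order in which forward unpacks them.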