import torch.nn as nn
import torch.nn.functional as F

try:
    from flash_attn.ops.activations import swiglu as flash_swiglu
except ImportError:
    flash_swiglu = None
if flash_swiglu is None:
    # Pure-PyTorch fallback for SwiGLU: split the input in half along the
    # last dimension and gate one half with the SiLU of the other.
    class SwiGLU(nn.Module):
        def forward(self, x):
            x, gate = x.chunk(2, dim=-1)
            return F.silu(gate) * x

    # Reuse a single (stateless) module instance instead of constructing
    # a new one on every call.
    _swiglu = SwiGLU()

    def swiglu(x):
        return _swiglu(x)
else:
    # Use Flash Attention's fused swiglu kernel when it is installed.
    def swiglu(x):
        return flash_swiglu(x)
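
# Minimal usage sketch (illustrative; assumes the flash path accepts the
# same packed input as the PyTorch fallback, i.e. the value and gate
# projections concatenated along the last dimension):
if __name__ == "__main__":
    import torch

    x = torch.randn(4, 16, 256)  # (batch, seq_len, 2 * d_ff)
    out = swiglu(x)              # -> (4, 16, 128): half the last dimension
    print(out.shape)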