import torch.nn as nn
import torch.nn.functional as F

# Prefer the fused flash-attn kernel when it is available; otherwise fall back
# to a plain PyTorch implementation.
try:
    from flash_attn.ops.activations import swiglu as flash_swiglu
except ImportError:
    flash_swiglu = None


if flash_swiglu is None:
    # PyTorch implementation of SwiGLU: split the last dimension into a value
    # half and a gate half, then compute silu(gate) * value.
    class SwiGLU(nn.Module):
        def forward(self, x):
            x, gate = x.chunk(2, dim=-1)
            return F.silu(gate) * x

    _swiglu = SwiGLU()  # the module is stateless, so one shared instance suffices

    def swiglu(x):
        return _swiglu(x)
else:
    # Use flash-attn's fused swiglu. It takes the two halves as separate
    # arguments and computes silu(first) * second, so chunk the input here
    # to match the fallback's behaviour.
    def swiglu(x):
        x, gate = x.chunk(2, dim=-1)
        return flash_swiglu(gate, x)
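
# Illustrative usage sketch (not part of the original snippet): shapes and the
# device choice below are arbitrary examples. Both code paths expect the last
# dimension to be even, since it is split into a value half and a gate half;
# the fused flash-attn path is assumed to require CUDA tensors.
if __name__ == "__main__":
    import torch

    device = "cuda" if flash_swiglu is not None else "cpu"
    x = torch.randn(2, 16, 512, device=device)  # (batch, seq_len, 2 * hidden_dim)
    out = swiglu(x)
    print(out.shape)  # torch.Size([2, 16, 256])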