import torch.nn as nn
import torch.nn.functional as F

try:
    from flash_attn.ops.activations import swiglu as flash_swiglu
except ImportError:
    flash_swiglu = None

if flash_swiglu is None:
    # PyTorch implementation of SwiGLU
    class SwiGLU(nn.Module):
        def forward(self, x):
            x, gate = x.chunk(2, dim=-1)
            return F.silu(gate) * x

    def swiglu(x):
        layer = SwiGLU()
        return layer(x)
else:
    # Use Flash Attention's fused swiglu. flash_attn.ops.activations.swiglu
    # operates on two separate tensors (computing silu(a) * b), so split the
    # packed projection here so this branch matches the PyTorch fallback above.
    def swiglu(x):
        x, gate = x.chunk(2, dim=-1)
        return flash_swiglu(gate, x)
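

# Usage sketch (not part of the original file): `swiglu` expects the output of
# a single linear projection whose last dimension packs the value and gate
# halves side by side, so an input of width 2 * d_ff is reduced to d_ff. The
# shapes are illustrative; with flash-attn installed, move the tensor to CUDA
# before calling the fused path.
if __name__ == "__main__":
    import torch

    x = torch.randn(4, 128, 2 * 512)  # (batch, seq_len, 2 * d_ff)
    y = swiglu(x)
    print(y.shape)  # torch.Size([4, 128, 512])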