custom GEGLU implementation #32
opened by brwang
I was looking at this function:
```python
import torch
from typing import Optional


def quick_gelu(x):
    # quick_gelu was defined elsewhere in the repo; the standard
    # sigmoid-based GELU approximation is assumed here.
    return x * torch.sigmoid(1.702 * x)


def gegelu(input, limit: Optional[float] = None):
    # Take the gated and linear parts from interleaved (even/odd) channels.
    a_gelu, a_linear = input[..., ::2], input[..., 1::2]
    if limit is not None:
        # Clamp finite values to the limit; leave infinities untouched
        # so they propagate through the where().
        a_gelu = torch.where(
            torch.isinf(a_gelu), a_gelu, a_gelu.clamp(min=None, max=limit)
        )
        a_linear = torch.where(
            torch.isinf(a_linear), a_linear, a_linear.clamp(min=-limit, max=limit)
        )
    out_gelu = quick_gelu(a_gelu)
    return out_gelu * (a_linear + 1)
```
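For context, this is how I exercised it (my own test snippet, not code from the repo); the output's last dimension is half the input's:

```python
x = torch.randn(2, 8)       # last dim is 2*d with d = 4
y = gegelu(x, limit=6.0)    # limit value chosen arbitrarily for the test
print(y.shape)              # torch.Size([2, 4])
```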
I'm wondering why:
- There is an (a_linear + 1) term; in a standard GEGLU this would usually just be a_linear (see the reference sketch below).
- a_gelu and a_linear are taken from interleaved (even/odd) channels, instead of splitting the input tensor in half along the last dimension.
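For comparison, this is the behavior I expected, roughly following the GEGLU from the GLU Variants paper (my own sketch, not code from this repo):

```python
import torch
import torch.nn.functional as F

def geglu_reference(x):
    # Textbook GEGLU: split the last dimension in half contiguously,
    # apply GELU to one half, and multiply by the other half as the gate.
    a, b = x.chunk(2, dim=-1)
    return F.gelu(a) * b
```

Note there is no "+ 1" term and the split is a contiguous half split rather than even/odd interleaving.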
Thank you in advance! I tried looking these things up, but I couldn't find anything about them online.