custom GEGLU implementation

#32
by brwang

I was looking at this function:

import torch
from typing import Optional

def quick_gelu(x):
    # Standard "quick GELU" approximation: x * sigmoid(1.702 * x)
    return x * torch.sigmoid(1.702 * x)

def gegelu(input, limit: Optional[float] = None):
    # Split the last dimension into gate and linear parts by interleaving
    a_gelu, a_linear = input[..., ::2], input[..., 1::2]
    if limit is not None:
        # Clamp finite values (the gate from above only, the linear part
        # symmetrically); infinities pass through unchanged.
        a_gelu = torch.where(
            torch.isinf(a_gelu), a_gelu, a_gelu.clamp(min=None, max=limit)
        )
        a_linear = torch.where(
            torch.isinf(a_linear), a_linear, a_linear.clamp(min=-limit, max=limit)
        )
    out_gelu = quick_gelu(a_gelu)
    # Note the +1 shift on the linear branch, unlike standard GEGLU
    return out_gelu * (a_linear + 1)
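
For context, a quick smoke test of the function above (the shapes here are my own example, not from the model):

x = torch.randn(2, 8)       # last dimension must be even
y = gegelu(x, limit=10.0)
print(y.shape)              # torch.Size([2, 4]): output width is half the input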

I'm wondering about two things:

  1. Why is there an (a_linear + 1) term? Standard GEGLU would use just a_linear (see the first sketch after this list).
  2. Why are a_gelu and a_linear split in an interleaved way, instead of splitting the input tensor in half (see the second sketch below)?
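
On point 1, for comparison, here is a minimal sketch of standard GEGLU as described in Shazeer's "GLU Variants Improve Transformer" (2020): a half split and no +1 shift. The function name is my own:

import torch
import torch.nn.functional as F

def geglu_standard(x):
    # Standard GEGLU: split the projection in half, gate one half with GELU
    a, b = x.chunk(2, dim=-1)
    return F.gelu(a) * b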
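
On point 2: mathematically, an interleaved split equals a half split after permuting the output columns of the preceding projection, so the two differ only in weight layout. A small sketch demonstrating the equivalence (all names and sizes are my own illustration):

import torch

d = 4
W = torch.randn(16, 2 * d)                      # fused up-projection weight
x = torch.randn(3, 16)
h = x @ W

# Interleaved split of the fused output...
gate_i, lin_i = h[..., ::2], h[..., 1::2]

# ...matches a half split after reordering the weight's columns.
perm = torch.cat([torch.arange(0, 2 * d, 2), torch.arange(1, 2 * d, 2)])
h2 = x @ W[:, perm]
gate_h, lin_h = h2.chunk(2, dim=-1)

assert torch.allclose(gate_i, gate_h) and torch.allclose(lin_i, lin_h)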

Thank you in advance! I tried looking these things up, but I couldn't find anything about them online.
