| # -*- coding: utf-8 -*- | |
| # Copyright (c) 2024, Songlin Yang, Yu Zhang | |
| import os | |
| import triton | |
| import triton.language as tl | |
| import triton.language.extra.libdevice as tldevice | |
| from fla.utils import is_gather_supported | |
# Select the elementwise math primitives used by kernels in this package.
# The environment variable FLA_USE_FAST_OPS is read once, at import time:
# when it is '1', libdevice's `fast_*` variants are used; otherwise the
# standard `triton.language` ops are used.
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
    div = tldevice.fast_dividef
    exp = tldevice.fast_expf
    log = tldevice.fast_logf
    log2 = tldevice.fast_log2f
else:
    @triton.jit
    def div_normal(x, y):
        """Return ``x / y`` using Triton's regular (non-fast) division.

        Must be ``@triton.jit``: `div` is meant to be called from inside JIT
        kernels, and a plain Python function would be executed eagerly on
        Triton tensors, where the ``/`` operator cannot be lowered outside a
        JIT compilation context.
        """
        return x / y

    div = div_normal
    exp = tl.exp
    log = tl.log
    log2 = tl.log2
@triton.jit
def safe_exp(x):
    """Elementwise ``exp(x)`` where ``x <= 0``, and ``0`` everywhere else.

    Every element that does not satisfy ``x <= 0`` (positive values, and NaN
    since the comparison is false for it) is replaced by ``-inf`` before
    exponentiating, so it maps to exactly 0 and can never overflow.
    Uses the module-level ``exp``, which is either the libdevice fast variant
    or ``tl.exp`` depending on FLA_USE_FAST_OPS.

    Must be ``@triton.jit`` so it can be called from inside other kernels.
    """
    return exp(tl.where(x <= 0, x, float('-inf')))
# Expose `gather`: the real `tl.gather` when this Triton build provides it,
# otherwise an inert stand-in with a permissive signature.
if is_gather_supported:
    gather = tl.gather
else:
    def gather(*args, **kwargs):
        """Placeholder for ``tl.gather`` on Triton builds that lack it.

        Accepts any arguments and always returns ``None``. Presumably it
        exists only so that code referring to ``gather`` can still be
        defined; callers must not rely on its result (TODO confirm every
        call site is guarded by ``is_gather_supported``).
        """
        return None