# Spaces: Running on L40S
# Star import kept as-is for compatibility: it supplies `Literal` (used by
# set_backend's annotation) and leaves the typing names in this module's
# namespace, which downstream code may rely on.
from typing import *  # noqa: F403

# Attention backend name; defaults to 'flash_attn' and may be overridden by
# the ATTN_BACKEND environment variable at import time or by set_backend().
BACKEND = 'flash_attn'
# Debug flag; may be overridden by ATTN_DEBUG at import time or by set_debug().
DEBUG = False


def __from_env():
    """Initialise BACKEND and DEBUG from the process environment.

    * ATTN_BACKEND: if set to one of 'xformers', 'flash_attn', 'sdpa' or
      'naive', it becomes BACKEND. Any other value is ignored with a
      warning, and the current BACKEND is kept.
    * ATTN_DEBUG: if set, DEBUG becomes True only when the value is exactly
      '1'; any other value (including '') sets DEBUG to False.

    Unset variables leave the corresponding global untouched. The selected
    backend is always printed.
    """
    import os
    import warnings

    global BACKEND
    global DEBUG

    valid_backends = ('xformers', 'flash_attn', 'sdpa', 'naive')
    env_attn_backend = os.environ.get('ATTN_BACKEND')
    env_attn_debug = os.environ.get('ATTN_DEBUG')

    if env_attn_backend is not None:
        if env_attn_backend in valid_backends:
            BACKEND = env_attn_backend
        else:
            # Previously this was dropped silently; surface it so a typo such
            # as 'flash-attn' does not go unnoticed.
            warnings.warn(
                f"[ATTENTION] Ignoring unknown ATTN_BACKEND={env_attn_backend!r}; "
                f"expected one of {valid_backends}. Keeping {BACKEND!r}.",
                stacklevel=2,
            )
    if env_attn_debug is not None:
        DEBUG = env_attn_debug == '1'
    print(f"[ATTENTION] Using backend: {BACKEND}")


# Runs once, when this module is first imported.
__from_env()
def set_backend(backend: Literal['xformers', 'flash_attn', 'sdpa', 'naive']):
    """Select the attention backend at runtime.

    Overrides the value chosen at import time (the default or ATTN_BACKEND).
    The Literal lists the same four names that the ATTN_BACKEND environment
    variable accepts. The value is stored as-is, without validation.

    NOTE(review): whether a change takes effect for already-imported
    submodules depends on how they read BACKEND — confirm.

    Args:
        backend: name of the backend to use.
    """
    global BACKEND
    BACKEND = backend
def set_debug(debug: bool):
    """Turn attention debug mode on or off at runtime.

    Stores ``debug`` unchanged in the module-level DEBUG flag. This
    overrides whatever ATTN_DEBUG selected when the module was imported.
    """
    global DEBUG
    DEBUG = debug
# Submodules are imported last, after BACKEND/DEBUG have been initialised
# above. NOTE(review): presumably they read BACKEND at import time, which
# would make this ordering load-bearing — confirm before moving them up.
from .full_attn import *  # noqa: E402,F401,F403 -- re-exported
from .modules import *  # noqa: E402,F401,F403 -- re-exported