Spaces:
Runtime error
Runtime error
from contextlib import contextmanager | |
from functools import update_wrapper | |
import os | |
import threading | |
import torch | |
def get_use_compile():
    """Report whether the ``K_DIFFUSION_USE_COMPILE`` env var enables compilation.

    An unset variable counts as ``"1"`` (enabled). Any value other than exactly
    ``"1"`` disables it. The environment is re-read on every call.
    """
    return os.getenv("K_DIFFUSION_USE_COMPILE", default="1") == "1"
def get_use_flash_attention_2():
    """Read the ``K_DIFFUSION_USE_FLASH_2`` environment flag.

    The flag defaults to ``"1"``, so it is on unless explicitly set to another
    value. Going by the name, it presumably gates FlashAttention-2 use
    elsewhere in the project (not visible here).
    """
    setting = os.getenv("K_DIFFUSION_USE_FLASH_2", "1")
    return setting == "1"
def _make_state():
    """Build the thread-local flag holder with its default flag values.

    The defaults are assigned on the thread that calls this (the importing
    thread). Other threads start with no ``checkpointing`` attribute.
    """
    holder = threading.local()
    holder.checkpointing = False
    return holder


# Per-thread storage for runtime flags such as ``checkpointing``.
state = _make_state()
@contextmanager
def checkpointing(enable=True):
    """Temporarily set the calling thread's checkpointing flag.

    Use as ``with checkpointing(): ...``. The previous value is restored on
    exit, even when the body raises, so nested uses compose correctly.

    Args:
        enable: Value the flag holds inside the ``with`` block.

    Yields:
        None.
    """
    # ``state`` is a threading.local: the module-level default exists only on
    # the thread that imported this module. Plain attribute access would raise
    # AttributeError on other threads, so fall back to False, matching
    # get_checkpointing(). Reading before the ``try`` also guarantees the
    # ``finally`` never sees an unbound name.
    old_checkpointing = getattr(state, "checkpointing", False)
    state.checkpointing = enable
    try:
        yield
    finally:
        state.checkpointing = old_checkpointing
def get_checkpointing():
    """Return the calling thread's checkpointing flag.

    Threads that never assigned the flag report ``False``.
    """
    try:
        return state.checkpointing
    except AttributeError:
        return False
class compile_wrap:
    """Callable wrapper that defers ``torch.compile`` until first invocation.

    Whether to compile is decided by ``get_use_compile()`` at the time of the
    first call, not at construction. The resolved callable is cached
    afterwards. If ``torch.compile`` raises ``RuntimeError``, the plain
    function is used instead.
    """

    def __init__(self, function, *args, **kwargs):
        """Wrap ``function``.

        Args:
            function: The callable to wrap.
            *args: Positional arguments forwarded unchanged to ``torch.compile``.
            **kwargs: Keyword arguments forwarded unchanged to ``torch.compile``.
        """
        self.function = function
        self.args = args
        self.kwargs = kwargs
        self._compiled_function = None  # filled in lazily by compiled_function
        # Copy __name__, __doc__, __wrapped__, etc. so the wrapper looks like ``function``.
        update_wrapper(self, function)

    @property
    def compiled_function(self):
        """The callable that ``__call__`` dispatches to, resolved once and cached.

        This must be a property: ``__call__`` invokes the *result* of this
        attribute with the caller's arguments.
        """
        if self._compiled_function is not None:
            return self._compiled_function
        if get_use_compile():
            try:
                self._compiled_function = torch.compile(self.function, *self.args, **self.kwargs)
            except RuntimeError:
                # NOTE(review): torch.compile usually compiles lazily on the first
                # call of the returned object, so compilation failures may surface
                # from __call__ rather than here — confirm if this fallback matters.
                self._compiled_function = self.function
        else:
            self._compiled_function = self.function
        return self._compiled_function

    def __call__(self, *args, **kwargs):
        """Call the (possibly compiled) wrapped function with the given arguments."""
        return self.compiled_function(*args, **kwargs)