Spaces:
Running
on
Zero
Running
on
Zero
"""k-diffusion transformer diffusion models, version 2.
Code adapted from https://github.com/crowsonkb/k-diffusion
"""
| from contextlib import contextmanager | |
| from functools import update_wrapper | |
| import os | |
| import threading | |
| import torch | |
def get_use_compile():
    """Return whether the ``K_DIFFUSION_USE_COMPILE`` switch is on.

    The switch defaults to on: only an explicit value other than ``"1"``
    (for example ``"0"``) turns it off. ``compile_wrap`` consults it before
    calling ``torch.compile``.
    """
    setting = os.getenv("K_DIFFUSION_USE_COMPILE", "1")
    return setting == "1"
def get_use_flash_attention_2():
    """Return whether the ``K_DIFFUSION_USE_FLASH_2`` switch is on.

    The switch defaults to on: only an explicit value other than ``"1"``
    (for example ``"0"``) turns it off. Presumably it gates a FlashAttention-2
    code path elsewhere in the project; that caller is not visible here.
    """
    setting = os.getenv("K_DIFFUSION_USE_FLASH_2", "1")
    return setting == "1"
class _CheckpointState(threading.local):
    """Thread-local holder for the activation-checkpointing flag."""

    def __init__(self):
        super().__init__()
        # threading.local runs __init__ the first time each thread touches the
        # object, so every thread (not just the importing one) starts with
        # ``checkpointing`` defined and disabled.
        self.checkpointing = False


# Module-wide, per-thread checkpointing state read by get_checkpointing() and
# toggled by the checkpointing() context manager.
state = _CheckpointState()
@contextmanager
def checkpointing(enable=True):
    """Temporarily set the current thread's checkpointing flag.

    Use as ``with checkpointing(): ...``. Inside the ``with`` body,
    ``state.checkpointing`` (and therefore ``get_checkpointing()``) equals
    ``enable``. On exit, whether normal or via an exception, the previous
    value is restored.

    Args:
        enable: Value to set for the duration of the context. Defaults to True.
    """
    # ``state`` is a threading.local, so a thread other than the importing one
    # may not have the attribute yet: read it with a default. Capturing it
    # before ``try`` also guarantees the ``finally`` never sees an unbound name.
    previous = getattr(state, "checkpointing", False)
    state.checkpointing = enable
    try:
        yield
    finally:
        state.checkpointing = previous
def get_checkpointing():
    """Return the current thread's checkpointing flag, or False if it was never set."""
    try:
        return state.checkpointing
    except AttributeError:
        # Attribute not set in this thread's view of the thread-local ``state``.
        return False
class compile_wrap:
    """Callable wrapper that runs ``function`` through ``torch.compile`` lazily.

    Compilation is deferred to the first call. As a result, the
    ``K_DIFFUSION_USE_COMPILE`` switch (read via ``get_use_compile()``) is
    checked at call time rather than at wrap time. The wrapper copies
    ``function``'s metadata (``__name__``, ``__doc__``, ``__wrapped__``, ...)
    with ``functools.update_wrapper``.

    Args:
        function: The callable to wrap.
        *args: Forwarded verbatim to ``torch.compile``.
        **kwargs: Forwarded verbatim to ``torch.compile``.
    """

    def __init__(self, function, *args, **kwargs):
        self.function = function
        self.args = args
        self.kwargs = kwargs
        # Resolved callable (compiled or the original); None until first use.
        self._compiled_function = None
        update_wrapper(self, function)

    @property
    def compiled_function(self):
        """Callable to invoke, resolved on first access and cached afterwards.

        This must be a property because ``__call__`` calls the value returned
        here (``self.compiled_function(*args, **kwargs)``), not this accessor.

        Returns:
            The ``torch.compile`` result when ``get_use_compile()`` is True
            and compiling does not raise ``RuntimeError``; otherwise the
            original ``function``.
        """
        if self._compiled_function is not None:
            return self._compiled_function
        if get_use_compile():
            try:
                self._compiled_function = torch.compile(
                    self.function, *self.args, **self.kwargs
                )
            except RuntimeError:
                # Deliberate best-effort fallback: presumably torch.compile
                # raises RuntimeError when unsupported in this environment,
                # so run the uncompiled function instead of failing.
                self._compiled_function = self.function
        else:
            self._compiled_function = self.function
        return self._compiled_function

    def __call__(self, *args, **kwargs):
        # Forward the call to the lazily resolved (possibly compiled) callable.
        return self.compiled_function(*args, **kwargs)