|
from torch.utils.checkpoint import checkpoint |
|
|
|
import ldm.modules.attention |
|
import ldm.modules.diffusionmodules.openaimodel |
|
|
|
|
|
def BasicTransformerBlock_forward(self, x, context=None):
    """Checkpointed replacement for ``BasicTransformerBlock.forward``.

    Runs ``self._forward(x, context)`` through ``torch.utils.checkpoint`` so the
    segment's activations are recomputed during backward instead of stored.

    ``use_reentrant`` is passed explicitly because recent PyTorch releases warn
    when it is omitted and have announced the implicit default will change.
    ``True`` is the value the implicit default resolved to, so gradients are
    unchanged.
    NOTE(review): per the PyTorch docs, the reentrant variant only yields
    gradients for the segment when at least one tensor input requires grad;
    PyTorch recommends ``use_reentrant=False`` — consider switching if needed.
    """
    return checkpoint(self._forward, x, context, use_reentrant=True)
|
|
|
|
|
def AttentionBlock_forward(self, x):
    """Checkpointed replacement for ``AttentionBlock.forward``.

    Runs ``self._forward(x)`` through ``torch.utils.checkpoint`` so the
    segment's activations are recomputed during backward instead of stored.

    ``use_reentrant`` is passed explicitly because recent PyTorch releases warn
    when it is omitted and have announced the implicit default will change.
    ``True`` is the value the implicit default resolved to, so gradients are
    unchanged.
    NOTE(review): per the PyTorch docs, the reentrant variant only yields
    gradients for the segment when at least one tensor input requires grad.
    """
    return checkpoint(self._forward, x, use_reentrant=True)
|
|
|
|
|
def ResBlock_forward(self, x, emb):
    """Checkpointed replacement for ``ResBlock.forward``.

    Runs ``self._forward(x, emb)`` through ``torch.utils.checkpoint`` so the
    segment's activations are recomputed during backward instead of stored.

    ``use_reentrant`` is passed explicitly because recent PyTorch releases warn
    when it is omitted and have announced the implicit default will change.
    ``True`` is the value the implicit default resolved to, so gradients are
    unchanged.
    NOTE(review): per the PyTorch docs, the reentrant variant only yields
    gradients for the segment when at least one tensor input requires grad.
    """
    return checkpoint(self._forward, x, emb, use_reentrant=True)
|
|
|
|
|
# Original ``forward`` methods saved by add() and put back by remove();
# an empty list means the checkpointing patch is not currently installed.
stored: list = list()
|
|
|
|
|
def add():
    """Patch the ldm ``forward`` methods to run under gradient checkpointing.

    The current ``forward`` of ``BasicTransformerBlock``, ``ResBlock`` and
    ``AttentionBlock`` is saved in ``stored`` (in that order) and replaced by
    this module's checkpointing wrappers. Calling ``add()`` again while the
    patch is installed is a no-op, so the saved originals are never
    overwritten by the wrappers themselves.
    """
    # A non-empty ``stored`` means the patch is already in place.
    if stored:
        return

    attention = ldm.modules.attention
    openaimodel = ldm.modules.diffusionmodules.openaimodel

    # Order matters: remove() restores these by position.
    stored.extend([
        attention.BasicTransformerBlock.forward,
        openaimodel.ResBlock.forward,
        openaimodel.AttentionBlock.forward,
    ])

    attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
    openaimodel.ResBlock.forward = ResBlock_forward
    openaimodel.AttentionBlock.forward = AttentionBlock_forward
|
|
|
|
|
def remove():
    """Undo add(): restore the ``forward`` methods saved in ``stored``.

    Does nothing when ``stored`` is empty (the patch is not installed).
    ``stored`` is emptied afterwards so a later add() can patch again.

    Raises:
        ValueError: if ``stored`` does not hold exactly the three entries
            add() saves. This is checked before any class is modified, so a
            corrupted ``stored`` never leaves the classes half-restored.
    """
    if not stored:
        return

    # Unpack first: validates the arity before touching any class, and names
    # the entries in the order add() saved them.
    transformer_forward, resblock_forward, attention_forward = stored

    attention = ldm.modules.attention
    openaimodel = ldm.modules.diffusionmodules.openaimodel

    attention.BasicTransformerBlock.forward = transformer_forward
    openaimodel.ResBlock.forward = resblock_forward
    openaimodel.AttentionBlock.forward = attention_forward

    stored.clear()
|
|
|
|