from torch import Tensor


class DummyCondStage:
    """Identity stand-in for a conditioning stage model.

    Every operation passes its input through unchanged. ``encode`` wraps the
    input in the tuple layout ``(c, None, (None, None, c))``. The class holds no
    parameters and no mutable state apart from the two attributes set in
    ``__init__``.
    """

    def __init__(self, conditional_key):
        """Record the conditioning key; no model weights are created.

        Args:
            conditional_key: Stored verbatim on ``self.conditional_key``.
                Presumably the key callers use to look up the conditioning
                input; confirm at the call sites.
        """
        self.conditional_key = conditional_key
        # NOTE(review): ``train`` shadows what would usually be a method name.
        # It is set to None here, presumably so callers can assign their own
        # callable (e.g. to disable train mode). Confirm before relying on it.
        self.train = None

    def eval(self) -> "DummyCondStage":
        """Return this same instance without changing any state.

        Returning ``self`` lets calls like ``stage = stage.eval()`` chain, as
        with a module-style ``eval()``.
        """
        return self

    @staticmethod
    def encode(c: Tensor):
        """Pass ``c`` through in the 3-tuple layout expected of an encoder.

        Returns:
            ``(c, None, (None, None, c))``: the input first, ``None`` second,
            and a nested triple whose last slot is the input again.
        """
        info = (None, None, c)
        return c, None, info

    @staticmethod
    def decode(c: Tensor) -> Tensor:
        """Return ``c`` unchanged (identity decode)."""
        return c

    @staticmethod
    def to_rgb(c: Tensor) -> Tensor:
        """Return ``c`` unchanged; no colour conversion is performed."""
        return c