from torch import Tensor
class DummyCondStage:
    """Pass-through stand-in for a conditioning stage.

    ``encode``, ``decode`` and ``to_rgb`` all hand their input back
    unchanged, so this object can be used where a conditioning stage with
    that interface is expected but no transformation is wanted.
    """

    def __init__(self, conditional_key):
        """Store the key that identifies the conditioning input.

        Args:
            conditional_key: Opaque identifier kept on the instance as
                ``conditional_key``; it is not interpreted here.
        """
        self.conditional_key = conditional_key
        # NOTE(review): ``train`` is deliberately an attribute set to None,
        # not a method, so calling ``stage.train()`` raises TypeError.
        # Presumably this prevents toggling the stage into training mode --
        # confirm against callers.
        self.train = None

    def eval(self):
        """Return ``self`` unchanged; there is no mode to switch."""
        return self

    # The three methods below never touch instance state. They are declared
    # static so that both ``DummyCondStage.encode(c)`` and ``stage.encode(c)``
    # work; without the decorator (and without ``self``) the instance call
    # would fail with a TypeError for an unexpected extra argument.

    @staticmethod
    def encode(c: Tensor) -> tuple:
        """Return the identity "encoding" of ``c``.

        Returns:
            A 3-tuple ``(c, None, (None, None, c))``: the input itself, a
            ``None`` placeholder, and a nested tuple whose last item is
            again the input.
        """
        return c, None, (None, None, c)

    @staticmethod
    def decode(c: Tensor) -> Tensor:
        """Return ``c`` unchanged."""
        return c

    @staticmethod
    def to_rgb(c: Tensor) -> Tensor:
        """Return ``c`` unchanged."""
        return c