import torch
from typing import List


class Singleton(type):
    """Metaclass that hands out a single shared instance per class.

    The first call to a class using this metaclass constructs the instance
    and caches it; every later call returns that cached object. Constructor
    arguments passed on later calls are ignored because no new instance is
    built.
    """

    # Maps each class to its one cached instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        """Return the cached instance of ``cls``, creating it on first use."""
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class Latents(metaclass=Singleton):
    """Process-wide accumulator for a sequence of latent tensors.

    Because the class uses the ``Singleton`` metaclass, every ``Latents()``
    call returns the same object, so the stored history is shared by all
    callers.
    """

    def __init__(self) -> None:
        # Tensors in the order they were added.
        self.history: List[torch.FloatTensor] = []

    def is_empty(self) -> bool:
        """Return ``True`` when no latents are currently stored."""
        # ``history`` is always a list (never ``None``), so emptiness has to
        # be tested by truthiness rather than identity with ``None``.
        return not self.history

    def add_latents(self, latents: torch.FloatTensor) -> None:
        """Append ``latents`` to the end of the stored history."""
        self.history.append(latents)

    def clear(self) -> None:
        """Discard the stored history by rebinding it to a fresh list."""
        self.history = []

    def dump_and_clear(self) -> List[torch.FloatTensor]:
        """Return the stored history and reset the store to empty.

        The returned list is the previous history object itself; ``clear``
        rebinds ``self.history`` to a new list, so later additions do not
        affect the returned value.
        """
        history = self.history
        self.clear()
        return history