# Source: BeamDiffusion / models / diffusionModel / Latents_Singleton.py
# Uploaded by Gui28F in commit 173ea2b ("uploaded all project files", verified).
import torch
from typing import List
class Singleton(type):
    """Metaclass that gives each class built with it a single shared instance.

    The first call ``SomeClass(*args, **kwargs)`` constructs the object
    normally and caches it. Every later call returns that cached object; the
    new arguments are ignored and ``__init__`` does not run again.
    """

    # Registry mapping each class created with this metaclass to its instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Fast path: this class has already produced its one instance.
        if cls in cls._instances:
            return cls._instances[cls]

        # First request: build the object via the normal type.__call__ path
        # (runs __new__ and __init__), then remember it.
        instance = super().__call__(*args, **kwargs)
        cls._instances[cls] = instance
        return instance
class Latents(metaclass=Singleton):
    """Shared store for a sequence of latent tensors.

    Because the class uses the ``Singleton`` metaclass, every ``Latents()``
    call returns the same object. Latents appended through one reference are
    therefore visible through every other reference.
    """

    def __init__(self) -> None:
        # Latents in the order they were added via add_latents().
        self.history: List[torch.FloatTensor] = []

    def is_empty(self) -> bool:
        """Return True when no latents are currently stored."""
        # history is normally a list, so test emptiness rather than identity
        # with None (which was never true). `not` also treats a None history
        # as empty.
        return not self.history

    def add_latents(self, latents: torch.FloatTensor) -> None:
        """Append ``latents`` to the end of the history."""
        self.history.append(latents)

    def clear(self) -> None:
        """Discard all stored latents by rebinding history to a new empty list."""
        self.history = []

    def dump_and_clear(self) -> List[torch.FloatTensor]:
        """Return the stored latents and reset the store.

        Returns:
            The previous history list object itself. ``clear()`` rebinds
            ``self.history`` to a new list rather than emptying the old one,
            so later additions do not show up in the returned list.
        """
        history = self.history
        self.clear()
        return history