Spaces:
Running
on
T4
Running
on
T4
from typing import Optional, Callable | |
from abc import ABC, abstractmethod | |
from app import ModelBuilder | |
class ModelDirector(ABC):
    """Direct a ``ModelBuilder`` to assemble a model for the option the user chose.

    Subclasses override the three hook methods ``get_ae_weights``,
    ``get_diffusion_weights`` and ``get_generating_condition`` to supply the
    values that ``config_setup`` hands to the builder. The base implementations
    return ``None``.

    The hooks are called once, in ``__init__``, and their results are cached
    on the instance; ``config_setup`` reuses the cached values.
    """

    def __init__(
        self,
        builder: Optional['ModelBuilder'] = None,
        additional_setup_fn: Optional[Callable[['ModelBuilder'], None]] = None,
    ):
        """Create the director and resolve its configuration hooks.

        Args:
            builder: Builder to configure. When ``None``, a fresh
                ``ModelBuilder()`` is created.
            additional_setup_fn: Optional callback that receives the builder.
                It runs inside ``config_setup`` after both weight setups and
                before the generating condition is set.
        """
        # Only construct a default builder when none was injected, so callers
        # can supply their own (pre-configured or substitute) builder.
        self._builder = ModelBuilder() if builder is None else builder
        self._additional_setup_fn = additional_setup_fn
        # Resolve the subclass hooks eagerly; config_setup reuses these values.
        self._ae_weights = self.get_ae_weights()
        self._diffusion_weights = self.get_diffusion_weights()
        self._condition = self.get_generating_condition()

    def config_setup(self):
        """Push the cached configuration into the builder.

        Order: autoencoder weights, diffusion weights, the optional
        user-defined setup callback, then the generating condition.
        """
        self._builder.setup_autoencoder_weights(self._ae_weights)
        self._builder.setup_diffusion_weights(self._diffusion_weights)
        # User-defined setup. Compare against None explicitly so a callable
        # object that happens to be falsy (e.g. defines __len__ -> 0) still runs.
        if self._additional_setup_fn is not None:
            self._additional_setup_fn(self._builder)
        self._builder.setup_condition(self._condition)

    def builder(self):
        """Return the builder this director configures."""
        return self._builder

    def buider(self):
        """Misspelled alias of ``builder()``, kept for backward compatibility."""
        return self.builder()

    def get_ae_weights(self):
        """Return the autoencoder weights to configure; override in subclasses.

        The base implementation returns ``None``.
        """
        return None

    def get_diffusion_weights(self):
        """Return the diffusion weights to configure; override in subclasses.

        The base implementation returns ``None``.
        """
        return None

    def get_generating_condition(self):
        """Return the generating condition to configure; override in subclasses.

        The base implementation returns ``None``.
        """
        return None