from abc import ABC, abstractmethod

import gradio as gr
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.model import Mammal


class MammalObjectBroker:
    """Lazy holder for a Mammal model and its matching tokenizer op.

    Nothing is loaded at construction time; the model and the tokenizer op
    are each created from ``model_path`` on first access of the corresponding
    property and then cached on the instance.
    """

    def __init__(
        self,
        model_path: str,
        name: str | None = None,
        task_list: list[str] | None = None,
    ) -> None:
        """Record where to load the model from and which tasks it serves.

        Args:
            model_path: Passed to ``from_pretrained`` of both the model and
                the tokenizer op when they are first requested.
            name: Name for this model; defaults to ``model_path``.
            task_list: Task names associated with this model. The list is
                stored as-is (not copied); ``None`` means an empty list.
        """
        self.model_path = model_path
        self.name = model_path if name is None else name
        self.tasks: list[str] = task_list if task_list is not None else []
        # Both are populated lazily by the properties below.
        self._model: Mammal | None = None
        self._tokenizer_op: ModularTokenizerOp | None = None

    @property
    def model(self) -> Mammal:
        """Return the model, loading it and switching it to eval mode on first use."""
        if self._model is None:
            self._model = Mammal.from_pretrained(self.model_path)
            self._model.eval()
        return self._model

    @property
    def tokenizer_op(self) -> ModularTokenizerOp:
        """Return the tokenizer op, loading it from ``model_path`` on first use."""
        if self._tokenizer_op is None:
            self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
        return self._tokenizer_op


class MammalTask(ABC):
    """Base class for a task that can be run on registered models and shown in a demo.

    Subclasses must implement ``crate_sample_dict`` and ``decode_output``.
    ``run_model`` and ``create_demo`` are not abstract, but their base
    implementations always raise ``NotImplementedError``.
    """

    def __init__(self, name: str, model_dict: dict[str, MammalObjectBroker]) -> None:
        """Store the task name and the mapping of available model holders.

        Args:
            name: Task name; ``TaskRegistry.register_task`` uses it as the key.
            model_dict: Mapping from model name to its ``MammalObjectBroker``.
        """
        self.name = name
        self.description = None
        # Cached result of create_demo(); built on the first call to demo().
        self._demo: gr.Group | None = None
        self.model_dict = model_dict

    # NOTE: the method name is spelled "crate" (sic). It is part of the
    # interface subclasses implement, so it is kept unchanged.
    @abstractmethod
    def crate_sample_dict(
        self, sample_inputs: dict, model_holder: MammalObjectBroker
    ) -> dict:
        """Build the sample dict that is fed into the model.

        Args:
            sample_inputs: Raw inputs for the task.
            model_holder: Holder of the model (and tokenizer op) to build
                the sample for.

        Returns:
            dict: sample_dict for feeding into the model.
        """
        raise NotImplementedError()

    # Not decorated with @abstractmethod, so a subclass can be instantiated
    # without overriding it; the base implementation always raises.
    def run_model(self, sample_dict: dict, model: Mammal):
        """Run ``model`` on ``sample_dict``; must be overridden to be usable.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

    def create_demo(self, model_name_widget: gr.components.Component) -> gr.Group:
        """Create a gradio demo group for this task.

        Args:
            model_name_widget: Widget holding the name of the model to use.
                This is needed to create gradio actions with the current
                model name as an input.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

    def demo(self, model_name_widgit: gr.components.Component | None = None):
        """Return the demo group, creating it with ``create_demo`` on the first call.

        Args:
            model_name_widgit: Forwarded to ``create_demo`` as
                ``model_name_widget`` when the demo is first built. It is
                ignored once a demo has been cached. The parameter name's
                spelling is kept for backward compatibility.

        Returns:
            The cached value returned by ``create_demo``.
        """
        if self._demo is None:
            self._demo = self.create_demo(model_name_widget=model_name_widgit)
        return self._demo

    @abstractmethod
    def decode_output(self, batch_dict, model: Mammal) -> list:
        """Decode model output into a list of results.

        Args:
            batch_dict: Presumably the batch dict produced by running the
                model -- confirm against the concrete task implementations.
            model: The model the output belongs to.
        """
        raise NotImplementedError()


class TaskRegistry(dict[str, MammalTask]):
    """A dictionary of tasks keyed by task name, with a register method."""

    def register_task(self, task: MammalTask) -> str:
        """Store ``task`` under ``task.name`` and return that name.

        An existing entry with the same name is replaced.
        """
        self[task.name] = task
        return task.name


class ModelRegistry(dict[str, MammalObjectBroker]):
    """A dictionary of model holders keyed by model name, with a register method."""

    def register_model(
        self,
        model_path: str,
        task_list: list[str] | None = None,
        name: str | None = None,
    ) -> str:
        """Register a model and return the name of the model.

        A ``MammalObjectBroker`` is created for the model; the model itself
        is not loaded here. An existing entry with the same name is replaced.

        Args:
            model_path: Location the model is later loaded from.
            task_list: Task names associated with the model.
            name: Explicit name for the model; defaults to ``model_path``.

        Returns:
            str: model name
        """
        model_holder = MammalObjectBroker(
            model_path=model_path, task_list=task_list, name=name
        )
        self[model_holder.name] = model_holder
        return model_holder.name