Spaces:
Sleeping
Sleeping
import gradio as gr | |
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp | |
from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask | |
from mammal.keys import * | |
from mammal.model import Mammal | |
from abc import ABC, abstractmethod | |
class MammalObjectBroker:
    """Hold a model path and lazily load/cache its MAMMAL model and tokenizer.

    Nothing is loaded at construction time. The model and the tokenizer are each
    loaded on the first call to :meth:`model` / :meth:`tokenizer_op` and the
    same object is returned on every later call.

    Attributes:
        model_path: Location passed to both ``from_pretrained`` calls.
        name: Display name for this model; defaults to ``model_path``.
        tasks: Task names associated with this model; an empty list when no
            ``task_list`` was given.
    """

    def __init__(
        self,
        model_path: str,
        name: "str | None" = None,
        task_list: "list[str] | None" = None,
    ) -> None:
        """Record the model location; does not load anything.

        Args:
            model_path: Location passed to ``Mammal.from_pretrained`` and
                ``ModularTokenizerOp.from_pretrained``.
            name: Display name. ``None`` means "use ``model_path``".
            task_list: Task names for this model. Stored by reference (not
                copied). ``None`` means "no tasks" (an empty list).
        """
        self.model_path = model_path
        self.name = model_path if name is None else name
        # Always define ``tasks`` so callers can rely on the attribute existing,
        # whether or not a task list was supplied.
        self.tasks = task_list if task_list is not None else []
        self._model = None
        self._tokenizer_op = None

    def model(self) -> "Mammal":
        """Return the cached model, loading it from ``model_path`` on first use.

        On first use the model is loaded with ``Mammal.from_pretrained`` and
        ``.eval()`` is called on it before it is cached.
        """
        if self._model is None:
            self._model = Mammal.from_pretrained(self.model_path)
            self._model.eval()
        return self._model

    def tokenizer_op(self) -> "ModularTokenizerOp":
        """Return the cached tokenizer op, loading it from ``model_path`` on first use."""
        if self._tokenizer_op is None:
            self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
        return self._tokenizer_op
class MammalTask(ABC):
    """Base class for a MAMMAL task shown in the Gradio demo.

    Subclasses override the hook methods below. In this base class every hook
    raises ``NotImplementedError``. None of them is declared abstract, so the
    base class (and partially implemented subclasses) can still be instantiated.

    Attributes:
        name: Task name.
        description: Optional description; ``None`` until a subclass sets it.
    """

    def __init__(self, name: str) -> None:
        """Store the task name and initialise the description and demo cache."""
        self.name = name
        self.description = None
        self._demo = None  # cached result of create_demo(), filled by demo()

    def crate_sample_dict(
        self, sample_inputs: dict, model_holder: "MammalObjectBroker"
    ) -> dict:
        """Build the model-input dict for one sample.

        The inputs are formatted to match the pre-training prompt syntax. The
        method-name spelling ("crate") is kept as-is so existing overrides and
        callers keep working.

        Args:
            sample_inputs: Raw inputs for this task (keys are task-specific;
                defined by the subclass).
            model_holder: Broker that provides the model and tokenizer.

        Returns:
            dict: sample_dict for feeding into the model.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

    def run_model(self, sample_dict, model: "Mammal"):
        """Run ``model`` on a prepared ``sample_dict``.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

    def create_demo(
        self, model_name_widget: "gr.components.Component"
    ) -> "gr.Group":
        """Create a Gradio demo group for this task.

        Args:
            model_name_widget: Widget holding the name of the model to use. It
                is needed so that Gradio actions can take the current model
                name as an input.

        Returns:
            gr.Group: The demo group for this task.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

    def demo(self, model_name_widgit: "gr.components.Component | None" = None):
        """Return this task's demo group, building it on the first call.

        The result of :meth:`create_demo` is cached. Later calls return the
        same object and ignore ``model_name_widgit``. The parameter's spelling
        is kept for callers that pass it by keyword.

        Args:
            model_name_widgit: Widget forwarded to :meth:`create_demo` as
                ``model_name_widget`` on the first call.
        """
        if self._demo is None:
            self._demo = self.create_demo(model_name_widget=model_name_widgit)
        return self._demo

    def decode_output(self, batch_dict, model: "Mammal"):
        """Decode model output held in ``batch_dict``.

        Presumably ``batch_dict`` is what :meth:`run_model` produced; confirm
        against the subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()