File size: 1,614 Bytes
acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e acb3380 084fe8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import base64
from typing import Any, Dict, Optional, Tuple, Type
class BaseSupervisor:
    """Base class and name-based factory for supervisors.

    Concrete supervisors register themselves under a string name with
    ``register_supervisor``. Calling ``BaseSupervisor(name, ...)`` then
    returns an instance of the subclass registered under ``name``.
    Subclasses are expected to override ``set_model``, ``ask_info`` and
    ``ask_score``; the base versions only raise ``NotImplementedError``.
    """

    # Maps registered supervisor names to their classes. The dict is defined
    # once on BaseSupervisor and only mutated in place (item assignment), so
    # every subclass reads and writes this same shared registry.
    _supervisor_registry: Dict[str, Type["BaseSupervisor"]] = {}

    @classmethod
    def register_supervisor(cls, supervisor_name: str) -> Any:
        """Return a class decorator that registers a subclass under a name.

        Args:
            supervisor_name: Key under which the decorated class is stored.
                Registering the same name again replaces the earlier class.

        Returns:
            A decorator that records the class in the registry and returns
            the class unchanged.
        """

        def decorator(
            subclass: Type["BaseSupervisor"],
        ) -> Type["BaseSupervisor"]:
            cls._supervisor_registry[supervisor_name] = subclass
            return subclass

        return decorator

    def __new__(cls, supervisor_name: str, *args: Any, **kwargs: Any) -> Any:
        """Allocate an instance of the class registered as ``supervisor_name``.

        Extra positional and keyword arguments are not used here. Python
        then calls ``__init__`` on the returned object with the full
        original argument list (including ``supervisor_name``), but only
        when that object is an instance of ``cls``.

        Raises:
            ValueError: If no class is registered under ``supervisor_name``.
        """
        try:
            target = cls._supervisor_registry[supervisor_name]
        except KeyError:
            raise ValueError(
                f"No supervisor registered with name '{supervisor_name}'"
            ) from None
        # Allocate the registered subclass rather than ``cls`` itself.
        return super().__new__(target)

    def set_model(
        self,
    ) -> None:
        """Configure the underlying model; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            "The 'set_model' method must be implemented in derived classes."
        )

    def ask(
        self, query: str, image_path: str, *, verbose: bool = True
    ) -> Tuple[str, float]:
        """Run ``ask_info`` and then ``ask_score``, returning both results.

        Args:
            query: Query forwarded to both ``ask_info`` and ``ask_score``.
            image_path: Passed to ``ask_info`` as its ``context`` argument.
            verbose: Forwarded to ``ask_score``. Defaults to ``True``.

        Returns:
            ``(gist, score)``, where ``gist`` is the ``ask_info`` result and
            ``score`` is the ``ask_score`` result computed for that gist.
        """
        gist = self.ask_info(query, image_path)
        score = self.ask_score(query, gist, verbose=verbose)
        return gist, score

    def ask_info(self, query: str, context: Optional[str] = None) -> str:
        """Produce a textual result ("gist") for ``query``; override this.

        Args:
            query: The question or prompt.
            context: Optional extra input; ``ask`` passes its ``image_path``
                here.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            "The 'ask_info' method must be implemented in derived classes."
        )

    def ask_score(self, query: str, gist: str, verbose: bool = False) -> float:
        """Return a numeric score for ``gist`` given ``query``; override this.

        Args:
            query: The question or prompt.
            gist: Text previously produced by ``ask_info``.
            verbose: Flag left to the subclass to interpret.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            "The 'ask_score' method must be implemented in derived classes."
        )
|