ctm-space / ctm /processors /processor_base.py
Haofei Yu
update the deployable ctm (#22)
084fe8e unverified
raw
history blame
5.13 kB
from typing import Any, Callable, Dict, Optional, Tuple, Type
from openai import OpenAI
from ..utils.decorator import score_exponential_backoff
class BaseProcessor(object):
_processor_registry: Dict[str, Type["BaseProcessor"]] = {}
@classmethod
def register_processor(
cls, processor_name: str
) -> Callable[[Type["BaseProcessor"]], Type["BaseProcessor"]]:
def decorator(
subclass: Type["BaseProcessor"],
) -> Type["BaseProcessor"]:
cls._processor_registry[processor_name] = subclass
return subclass
return decorator
def __new__(
cls, processor_name: str, *args: Any, **kwargs: Any
) -> "BaseProcessor":
if processor_name not in cls._processor_registry:
raise ValueError(
f"No processor registered with name '{processor_name}'"
)
return super(BaseProcessor, cls).__new__(
cls._processor_registry[processor_name]
)
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.init_scorer()
self.init_executor()
self.init_messenger()
self.init_task_info()
def init_executor(self) -> None:
raise NotImplementedError(
"The 'init_executor' method must be implemented in derived classes."
)
def init_messenger(self) -> None:
raise NotImplementedError(
"The 'init_messenger' method must be implemented in derived classes."
)
def init_task_info(self) -> None:
raise NotImplementedError(
"The 'init_task_info' method must be implemented in derived classes."
)
def init_scorer(self) -> None:
self.scorer = OpenAI()
def ask(
self, query: str, text: str, image: str, audio: str, video_frames: str
) -> Tuple[str, float]:
gist = self.ask_info(
query=query,
text=text,
image=image,
audio=audio,
video_frames=video_frames,
)
score = self.ask_score(query, gist, verbose=True)
return gist, score
@score_exponential_backoff(retries=5, base_wait_time=1)
def ask_relevance(self, query: str, gist: str) -> float:
response = self.scorer.chat.completions.create(
model="gpt-4-0125-preview",
messages=[
{
"role": "user",
"content": f"How related is the information ({gist}) with the query ({query})? Answer with a number from 0 to 5 and do not add any other thing.",
}
],
max_tokens=50,
)
score = (
float(response.choices[0].message.content.strip()) / 5
if response.choices[0].message.content
else 0.0
)
return score
@score_exponential_backoff(retries=5, base_wait_time=1)
def ask_confidence(self, query: str, gist: str) -> float:
response = self.scorer.chat.completions.create(
model="gpt-4-0125-preview",
messages=[
{
"role": "user",
"content": f"How confident do you think the information ({gist}) is a must-know? Answer with a number from 0 to 5 and do not add any other thing.",
}
],
max_tokens=50,
)
score = (
float(response.choices[0].message.content.strip()) / 5
if response.choices[0].message.content
else 0.0
)
return score
@score_exponential_backoff(retries=5, base_wait_time=1)
def ask_surprise(
self, query: str, gist: str, history_gists: Optional[str] = None
) -> float:
response = self.scorer.chat.completions.create(
model="gpt-4-0125-preview",
messages=[
{
"role": "user",
"content": f"How surprising do you think the information ({gist}) is as an output of the processor? Answer with a number from 0 to 5 and do not add any other thing.",
}
],
max_tokens=50,
)
score = (
float(response.choices[0].message.content.strip()) / 5
if response.choices[0].message.content
else 0.0
)
return score
def ask_score(
self,
query: str,
gist: str,
verbose: bool = False,
*args: Any,
**kwargs: Any,
) -> float:
relevance = self.ask_relevance(query, gist, *args, **kwargs)
confidence = self.ask_confidence(query, gist, *args, **kwargs)
surprise = self.ask_surprise(query, gist, *args, **kwargs)
if verbose:
print(
f"Relevance: {relevance}, Confidence: {confidence}, Surprise: {surprise}"
)
final_score = relevance * confidence * surprise
return final_score
def ask_info(self, *args: Any, **kwargs: Any) -> str:
raise NotImplementedError(
"The 'ask_info' method must be implemented in derived classes."
)