ctm-space / ctm /processors /processor_gpt4.py
Haofei Yu
update the deployable ctm (#22)
084fe8e unverified
raw
history blame
2.38 kB
from typing import Any, Dict, Optional
from openai import OpenAI
from ctm.messengers.messenger_base import BaseMessenger
from ctm.processors.processor_base import BaseProcessor
from ctm.utils.decorator import info_exponential_backoff
# Assuming the `register_processor` method has been updated to be properly typed:
@BaseProcessor.register_processor("gpt4_processor")
class GPT4Processor(BaseProcessor):
    """Processor that answers queries by calling an OpenAI chat model.

    Registered under the name ``"gpt4_processor"``. Prompts and replies are
    accumulated in a ``BaseMessenger`` named ``"gpt4_messenger"``, and the
    messages it returns from ``get_messages()`` are sent on every request.
    """

    # Chat-completions model and response-length cap used by gpt4_request.
    # Exposed as class attributes so subclasses (or callers) can override
    # them without touching the request code; the defaults equal the values
    # that used to be hard-coded in gpt4_request.
    gpt4_model_name: str = "gpt-4-turbo-preview"
    gpt4_max_tokens: int = 300

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

    def init_task_info(self) -> None:
        """Hook for task-specific setup; must be provided by a subclass.

        Raises:
            NotImplementedError: always, in this class.
        """
        raise NotImplementedError(
            "The 'init_task_info' method must be implemented in derived classes."
        )

    def init_executor(self) -> None:
        """Create the OpenAI client used for chat-completion requests."""
        self.executor = OpenAI()

    def init_messenger(self) -> None:
        """Create the messenger that accumulates the conversation messages."""
        self.messenger = BaseMessenger("gpt4_messenger")

    def process(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Placeholder: currently ignores ``payload`` and returns ``{}``."""
        return {}

    def update_info(self, feedback: str) -> None:
        """Append ``feedback`` to the messenger as an assistant message."""
        self.messenger.add_assistant_message(feedback)

    @info_exponential_backoff(retries=5, base_wait_time=1)
    def gpt4_request(self) -> str:
        """Send the messenger's messages to the chat model and return the reply.

        Wrapped in ``info_exponential_backoff(retries=5, base_wait_time=1)``,
        which presumably retries failed calls with exponential backoff —
        confirm against ctm.utils.decorator.

        Returns:
            The text of the first choice, or ``""`` when the model returned
            no text content.
        """
        response = self.executor.chat.completions.create(
            model=self.gpt4_model_name,
            messages=self.messenger.get_messages(),
            max_tokens=self.gpt4_max_tokens,
        )
        # The OpenAI SDK types message.content as Optional[str]; normalise
        # None to "" so ask_info (annotated -> str) always yields a string.
        content = response.choices[0].message.content
        return content if content is not None else ""

    def ask_info(
        self, query: str, text: Optional[str] = None, *args: Any, **kwargs: Any
    ) -> str:
        """Ask the model about ``text`` and return its reply.

        On the first round (``check_iter_round_num() == 0``) a user message
        combining ``text`` and ``self.task_instruction`` is added to the
        messenger; later rounds reuse whatever the messenger already holds.
        A model request is made on every call.

        Args:
            query: Currently unused; the instruction sent to the model comes
                from ``self.task_instruction`` instead.
            text: Text embedded in the first-round prompt; the placeholder
                ``"No text provided."`` is used when it is None.
            *args: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The model's reply text as returned by ``gpt4_request``.
        """
        if self.messenger.check_iter_round_num() == 0:
            initial_message = "The text information for the previously described task is as follows: "
            initial_message += (
                text if text is not None else "No text provided."
            )
            # NOTE(review): task_instruction is not assigned in this class;
            # presumably a subclass's init_task_info sets it — confirm.
            initial_message += (
                " Here is what you should do: " + self.task_instruction
            )
            self.messenger.add_user_message(initial_message)
        description = self.gpt4_request()
        return description
if __name__ == "__main__":
    # Manual smoke test: summarise a short snippet with the GPT-4 processor
    # and print the model's reply.
    # NOTE(review): init_task_info raises NotImplementedError in this class;
    # if BaseProcessor.__init__ calls it, this demo fails — confirm. The
    # OpenAI() client presumably also needs API credentials available.
    gpt4 = GPT4Processor()
    snippet = "Hugging Face has released a new version of Transformers that brings several enhancements."
    print(gpt4.ask_info(query="Summarize the changes.", text=snippet))