File size: 2,212 Bytes
084fe8e
acb3380
084fe8e
acb3380
084fe8e
acb3380
 
 
 
 
084fe8e
acb3380
084fe8e
 
 
 
acb3380
084fe8e
 
 
 
 
acb3380
084fe8e
 
 
 
 
 
 
acb3380
 
084fe8e
 
 
 
 
 
 
acb3380
084fe8e
 
 
 
 
 
 
acb3380
 
 
084fe8e
acb3380
084fe8e
acb3380
 
 
 
 
 
084fe8e
acb3380
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import json
import os
from typing import Any, Dict, Optional

from huggingface_hub import InferenceClient

from ctm.messengers.messenger_base import BaseMessenger
from ctm.processors.processor_base import BaseProcessor


@BaseProcessor.register_processor("bart_text_summary_processor")
class BartTextSummaryProcessor(BaseProcessor):
    """Processor that summarizes text with the hosted ``facebook/bart-large-cnn`` model.

    Requests are sent through a ``huggingface_hub.InferenceClient``
    authenticated with the ``HF_TOKEN`` environment variable. The
    conversation is tracked in a ``BaseMessenger``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Forward all arguments unchanged to BaseProcessor.
        # Presumably its __init__ calls the init_* hooks below (TODO confirm
        # against processor_base).
        super().__init__(*args, **kwargs)

    def init_executor(self) -> None:
        """Create the Hugging Face inference client.

        Raises:
            ValueError: If the ``HF_TOKEN`` environment variable is unset or empty.
        """
        hf_token = os.getenv("HF_TOKEN")
        if not hf_token:
            raise ValueError("HF_TOKEN environment variable is not set")
        self.executor = InferenceClient(token=hf_token)

    def init_messenger(self) -> None:
        """Create the messenger that stores this processor's message history."""
        self.messenger = BaseMessenger("bart_text_summ_messenger")

    def init_task_info(self) -> None:
        """No task-specific setup is needed for summarization."""
        pass

    def update_info(self, feedback: str) -> None:
        """Append ``feedback`` to the messenger as an assistant message."""
        self.messenger.add_assistant_message(feedback)

    def ask_info(
        self, text: Optional[str] = None, *args: Any, **kwargs: Any
    ) -> str | Any:
        """Summarize ``text`` with BART and return the summary string.

        Args:
            text: Text to summarize. If it is ``None``, a ``context`` keyword
                argument is used instead, so ``ask_info(context=...)`` works too.
            *args: Ignored; accepted for interface compatibility.
            **kwargs: Only ``context`` is read. Any other keyword (for example
                ``image_path``) is ignored.

        Returns:
            The ``summary_text`` field of the first element of the decoded
            JSON response.

        Raises:
            ValueError: If neither ``text`` nor ``context`` is provided.
        """
        if text is None:
            # The error message below and existing callers use the name
            # "context". Accept it as an alias so the input is not silently
            # swallowed by **kwargs.
            text = kwargs.get("context")
        if text is None:
            raise ValueError("Context must not be None")

        # Add the user's text only when the messenger reports round 0.
        if self.messenger.check_iter_round_num() == 0:
            self.messenger.add_user_message(text)

        # NOTE(review): "inputs" receives the messenger's full message history,
        # not the raw string. Confirm get_messages() returns something the
        # summarization endpoint accepts.
        # NOTE(review): newer huggingface_hub releases deprecate
        # InferenceClient.post. Verify this against the pinned version.
        raw = self.executor.post(
            json={"inputs": self.messenger.get_messages()},
            model="facebook/bart-large-cnn",
        )
        response: Dict[str, Any] = json.loads(raw)[0]
        return response["summary_text"]


if __name__ == "__main__":
    # Manual smoke test: summarize a short paragraph and print the result.
    # This needs HF_TOKEN to be set and network access to Hugging Face.
    processor = BartTextSummaryProcessor()
    text = (
        "In a shocking turn of events, Hugging Face has released a new version of Transformers "
        "that brings several enhancements and bug fixes. Users are thrilled with the improvements "
        "and are finding the new version to be significantly better than the previous one. "
        "The Hugging Face team is thankful for the community's support and continues to work "
        "towards making the library the best it can be."
    )
    # Pass the input as ``text``, which is ask_info's first parameter.
    # Passing ``context=`` would leave ``text`` as None and raise ValueError.
    # No image path is passed: this processor ignores extra keyword arguments.
    summary = processor.ask_info(text=text)
    print(summary)