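"""Job-input processors plus summarization and tagging helpers built on
google/flan-t5-large."""
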
import abc
import logging
import re

import httpx
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

from base import JobInput

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

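# model and tokenizer are loaded once at import time; flan-t5-large has roughly
# 780M parameters, so the first run downloads around 3 GB of weights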
MODEL_NAME = "google/flan-t5-large"
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


class Summarizer:
    def __init__(self) -> None:
        self.template = "Summarize the text below in two sentences:\n\n{}"
        self.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
        self.generation_config.max_new_tokens = 200
        self.generation_config.min_new_tokens = 100
        # top_k only takes effect when sampling is enabled
        self.generation_config.do_sample = True
        self.generation_config.top_k = 5
        self.generation_config.repetition_penalty = 1.5

    def __call__(self, x: str) -> str:
        text = self.template.format(x)
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model.generate(**inputs, generation_config=self.generation_config)
        output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        assert isinstance(output, str)
        return output

    def get_name(self) -> str:
        return f"Summarizer({MODEL_NAME})"


class Tagger:
    def __init__(self) -> None:
        self.template = (
            "Create a list of tags for the text below. The tags should be high level "
            "and specific. Prefix each tag with a hashtag.\n\n{}\n\nTags: #general"
        )
        self.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
        self.generation_config.max_new_tokens = 50
        self.generation_config.min_new_tokens = 25
        # increase the temperature to make the model more creative;
        # temperature only takes effect when sampling is enabled
        self.generation_config.do_sample = True
        self.generation_config.temperature = 1.5

    def _extract_tags(self, text: str) -> list[str]:
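        """Collect the unique, lower-cased hashtag tokens from the generated text."""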
        tags = set()
        for tag in text.split():
            if tag.startswith("#"):
                tags.add(tag.lower())
        return sorted(tags)

    def __call__(self, x: str) -> list[str]:
        text = self.template.format(x)
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model.generate(**inputs, generation_config=self.generation_config)
        output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        tags = self._extract_tags(output)
        return tags

    def get_name(self) -> str:
        return f"Tagger({MODEL_NAME})"


class Processor(abc.ABC):
    def __call__(self, job: JobInput) -> str:
        _id = job.id
        logger.info(f"Processing job with {self.__class__.__name__} (id={_id[:8]})")
        result = self.process(job)
        logger.info(f"Finished processing job (id={_id[:8]})")
        return result

    @abc.abstractmethod
    def process(self, input: JobInput) -> str:
        raise NotImplementedError

    @abc.abstractmethod
    def match(self, input: JobInput) -> bool:
        raise NotImplementedError

    @abc.abstractmethod
    def get_name(self) -> str:
        raise NotImplementedError


class RawProcessor(Processor):
    def match(self, input: JobInput) -> bool:
        return True

    def process(self, input: JobInput) -> str:
        return input.content

    def get_name(self) -> str:
        return self.__class__.__name__


class PlainUrlProcessor(Processor):
    def __init__(self) -> None:
        self.client = httpx.Client()
        self.regex = re.compile(r"(https?://[^\s]+)")
        self.url: str | None = None
        self.template = "{url}\n\n{content}"

    def match(self, input: JobInput) -> bool:
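        # side effect: remember the matched URL so that process() can use it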
        urls = list(self.regex.findall(input.content))
        if len(urls) == 1:
            self.url = urls[0]
            return True
        return False

    def process(self, input: JobInput) -> str:
        """Get content of website and return it as string"""
        assert isinstance(self.url, str)
        text = self.client.get(self.url).text
        assert isinstance(text, str)
        text = self.template.format(url=self.url, content=text)
        return text

    def get_name(self) -> str:
        return self.__class__.__name__


class ProcessorRegistry:
    def __init__(self) -> None:
        self.registry: list[Processor] = []
        self.default_registry: list[Processor] = []
        self.set_default_processors()

    def set_default_processors(self) -> None:
        self.default_registry.extend([PlainUrlProcessor(), RawProcessor()])

    def register(self, processor: Processor) -> None:
        self.registry.append(processor)

    def dispatch(self, input: JobInput) -> Processor:
        for processor in self.registry + self.default_registry:
            if processor.match(input):
                return processor

        # should never be reached, since RawProcessor matches everything,
        # but return a safe default just in case
        return RawProcessor()
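

# ---------------------------------------------------------------------------
# Minimal usage sketch, not exercised anywhere in this module: it assumes that
# JobInput (from base) can be constructed with `id` and `content` keyword
# arguments, matching the attributes accessed above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    registry = ProcessorRegistry()
    job = JobInput(id="0123456789abcdef", content="https://example.com")

    processor = registry.dispatch(job)  # PlainUrlProcessor matches a single URL
    text = processor(job)

    print(Summarizer()(text))
    print(Tagger()(text))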