import abc
import logging
import re

import httpx
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

from base import JobInput

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

MODEL_NAME = "google/flan-t5-large"
# Loaded once at import time; shared read-only by Summarizer and Tagger below.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


class Summarizer:
    """Summarize a text in two sentences using the shared seq2seq model."""

    def __init__(self) -> None:
        self.template = "Summarize the text below in two sentences:\n\n{}"
        self.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
        self.generation_config.max_new_tokens = 200
        self.generation_config.min_new_tokens = 100
        self.generation_config.top_k = 5
        self.generation_config.repetition_penalty = 1.5

    def __call__(self, x: str) -> str:
        """Return a model-generated two-sentence summary of ``x``."""
        text = self.template.format(x)
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model.generate(**inputs, generation_config=self.generation_config)
        output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        assert isinstance(output, str)
        return output

    def get_name(self) -> str:
        return f"Summarizer({MODEL_NAME})"


class Tagger:
    """Generate a sorted list of lowercase hashtags describing a text."""

    def __init__(self) -> None:
        self.template = (
            "Create a list of tags for the text below. The tags should be high level "
            "and specific. Prefix each tag with a hashtag.\n\n{}\n\nTags: #general"
        )
        self.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
        self.generation_config.max_new_tokens = 50
        self.generation_config.min_new_tokens = 25
        # increase the temperature to make the model more creative
        self.generation_config.temperature = 1.5

    def _extract_tags(self, text: str) -> list[str]:
        """Collect unique '#'-prefixed tokens from ``text``, lowercased and sorted."""
        tags = set()
        for tag in text.split():
            if tag.startswith("#"):
                tags.add(tag.lower())
        return sorted(tags)

    def __call__(self, x: str) -> list[str]:
        """Return model-generated tags for ``x`` as a sorted list of hashtags."""
        text = self.template.format(x)
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model.generate(**inputs, generation_config=self.generation_config)
        output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        tags = self._extract_tags(output)
        return tags

    def get_name(self) -> str:
        return f"Tagger({MODEL_NAME})"


class Processor(abc.ABC):
    """Base class for job processors: subclasses implement match/process/get_name."""

    def __call__(self, job: JobInput) -> str:
        _id = job.id
        # Fix: the original logged the *builtin* ``input`` function here
        # (f"Processing {input} ..."), not the job being processed.
        logger.info(
            "Processing job with %s (id=%s)", self.__class__.__name__, _id[:8]
        )
        result = self.process(job)
        logger.info("Finished processing input (id=%s)", _id[:8])
        return result

    def process(self, input: JobInput) -> str:
        """Transform the job input into a text payload."""
        raise NotImplementedError

    def match(self, input: JobInput) -> bool:
        """Return True if this processor can handle ``input``."""
        raise NotImplementedError

    def get_name(self) -> str:
        """Return a human-readable processor name."""
        raise NotImplementedError


class RawProcessor(Processor):
    """Fallback processor: passes the job content through unchanged."""

    def match(self, input: JobInput) -> bool:
        return True

    def process(self, input: JobInput) -> str:
        return input.content

    def get_name(self) -> str:
        return self.__class__.__name__


class PlainUrlProcessor(Processor):
    """Handles jobs whose content is exactly one URL: fetches and returns the page.

    NOTE(review): ``match`` stashes the found URL on ``self`` for ``process`` to
    use, so an instance is not safe to share across concurrent jobs — confirm
    the dispatcher always calls match() then process() on the same instance.
    """

    def __init__(self) -> None:
        self.client = httpx.Client()
        self.regex = re.compile(r"(https?://[^\s]+)")
        self.url = None
        self.template = "{url}\n\n{content}"

    def match(self, input: JobInput) -> bool:
        """Match only when the content contains exactly one http(s) URL."""
        urls = list(self.regex.findall(input.content))
        if len(urls) == 1:
            self.url = urls[0]
            return True
        return False

    def process(self, input: JobInput) -> str:
        """Get content of website and return it as string"""
        assert isinstance(self.url, str)
        text = self.client.get(self.url).text
        assert isinstance(text, str)
        text = self.template.format(url=self.url, content=text)
        return text

    def get_name(self) -> str:
        return self.__class__.__name__


class ProcessorRegistry:
    """Ordered registry of processors; user-registered ones take precedence."""

    def __init__(self) -> None:
        self.registry: list[Processor] = []
        self.default_registry: list[Processor] = []
        self.set_default_processors()

    def set_default_processors(self) -> None:
        self.default_registry.extend([PlainUrlProcessor(), RawProcessor()])

    def register(self, processor: Processor) -> None:
        self.registry.append(processor)

    def dispatch(self, input: JobInput) -> Processor:
        """Return the first processor whose match() accepts ``input``."""
        for processor in self.registry + self.default_registry:
            if processor.match(input):
                return processor
        # Should never be required (RawProcessor matches everything), but be safe.
        return RawProcessor()