"""Input processors, summarizers and taggers, plus a registry that dispatches to them."""

import abc
import logging
import re
from typing import Any, Optional

import httpx

from gistillery.base import JobInput

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class Processor(abc.ABC):
    """Turn a job's input into the text that is later summarized and tagged.

    Subclasses decide in :meth:`match` whether they handle a given input and
    produce the resulting text in :meth:`process`.
    """

    def get_name(self) -> str:
        """Return the name of this processor (the class name by default)."""
        return self.__class__.__name__

    def __call__(self, job: JobInput) -> str:
        """Run :meth:`process` on ``job``, logging start and end.

        Args:
            job: The job to process; only the first 8 characters of ``job.id``
                are logged.

        Returns:
            Whatever :meth:`process` returns.
        """
        short_id = job.id[:8]
        logger.info("Processing input with %s (id=%s)", self.get_name(), short_id)
        result = self.process(job)
        logger.info("Finished processing input (id=%s)", short_id)
        return result

    @abc.abstractmethod
    def process(self, input: JobInput) -> str:
        """Return the text derived from ``input``."""
        raise NotImplementedError

    @abc.abstractmethod
    def match(self, input: JobInput) -> bool:
        """Return True if this processor can handle ``input``."""
        raise NotImplementedError


class Summarizer(abc.ABC):
    """Interface for callables that turn a text into a summary."""

    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
        raise NotImplementedError

    def get_name(self) -> str:
        """Return a name identifying this summarizer; subclasses must override."""
        raise NotImplementedError

    @abc.abstractmethod
    def __call__(self, x: str) -> str:
        raise NotImplementedError


class Tagger(abc.ABC):
    """Interface for callables that turn a text into a list of tags."""

    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
        raise NotImplementedError

    def get_name(self) -> str:
        """Return a name identifying this tagger; subclasses must override."""
        raise NotImplementedError

    @abc.abstractmethod
    def __call__(self, x: str) -> list[str]:
        raise NotImplementedError


class MlRegistry:
    """Holds the registered processors, the summarizer and the tagger."""

    def __init__(self) -> None:
        self.processors: list[Processor] = []
        # NOTE: attribute name is misspelled; kept as-is so existing callers
        # that access ``registry.summerizer`` keep working.
        self.summerizer: Summarizer | None = None
        self.tagger: Tagger | None = None
        self.model = None
        self.tokenizer = None

    def register_processor(self, processor: Processor) -> None:
        """Append ``processor``; processors are tried in registration order."""
        self.processors.append(processor)

    def register_summarizer(self, summarizer: Summarizer) -> None:
        """Set the summarizer, replacing any previously registered one."""
        self.summerizer = summarizer

    def register_tagger(self, tagger: Tagger) -> None:
        """Set the tagger, replacing any previously registered one."""
        self.tagger = tagger

    def get_processor(self, input: JobInput) -> Processor:
        """Return the first registered processor whose ``match`` accepts ``input``.

        Falls back to a :class:`RawTextProcessor` when none matches.

        Raises:
            AssertionError: If no processor has been registered at all.
        """
        assert self.processors
        for processor in self.processors:
            if processor.match(input):
                return processor
        return RawTextProcessor()

    def get_summarizer(self) -> Summarizer:
        """Return the registered summarizer (AssertionError if none is set)."""
        assert self.summerizer
        return self.summerizer

    def get_tagger(self) -> Tagger:
        """Return the registered tagger (AssertionError if none is set)."""
        assert self.tagger
        return self.tagger


def _generate_text(model: Any, tokenizer: Any, generation_config: Any, prompt: str) -> str:
    """Tokenize ``prompt``, run ``model.generate`` and decode the first output sequence."""
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, generation_config=generation_config)
    output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    assert isinstance(output, str)
    return output


class HfTransformersSummarizer(Summarizer):
    """Summarizer that prompts a Hugging Face transformers model."""

    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
        self.model_name = model_name
        self.model = model
        self.tokenizer = tokenizer
        self.generation_config = generation_config
        self.template = "Summarize the text below in two sentences:\n\n{}"

    def __call__(self, x: str) -> str:
        """Return the model's decoded output for the summarization prompt built from ``x``."""
        prompt = self.template.format(x)
        return _generate_text(self.model, self.tokenizer, self.generation_config, prompt)

    def get_name(self) -> str:
        return f"{self.__class__.__name__}({self.model_name})"


class HfTransformersTagger(Tagger):
    """Tagger that prompts a Hugging Face transformers model for hashtags."""

    # Punctuation stripped from the end of a hashtag, e.g. "#python," -> "#python".
    _TRAILING_PUNCTUATION = ",.;:!?"

    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
        self.model_name = model_name
        self.model = model
        self.tokenizer = tokenizer
        self.generation_config = generation_config
        self.template = (
            "Create a list of tags for the text below. The tags should be high level "
            "and specific. Prefix each tag with a hashtag.\n\n{}\n\nTags: #general"
        )

    def _extract_tags(self, text: str) -> list[str]:
        """Return the sorted, lower-cased, de-duplicated hashtags found in ``text``."""
        tags: set[str] = set()
        for token in text.split():
            if not token.startswith("#"):
                continue
            tag = token.rstrip(self._TRAILING_PUNCTUATION).lower()
            if len(tag) > 1:  # skip a bare "#" (or "#" followed only by punctuation)
                tags.add(tag)
        return sorted(tags)

    def __call__(self, x: str) -> list[str]:
        """Prompt the model with ``x`` and extract the hashtags from its output."""
        prompt = self.template.format(x)
        output = _generate_text(self.model, self.tokenizer, self.generation_config, prompt)
        return self._extract_tags(output)

    def get_name(self) -> str:
        return f"{self.__class__.__name__}({self.model_name})"


class RawTextProcessor(Processor):
    """Fallback processor: accepts any input and returns its content unchanged."""

    def match(self, input: JobInput) -> bool:
        return True

    def process(self, input: JobInput) -> str:
        return input.content


class DefaultUrlProcessor(Processor):
    """Fetch the page behind the single URL in an input and return its text."""

    def __init__(self) -> None:
        # httpx does not follow redirects by default; without this an
        # http:// -> https:// (or moved-page) redirect would return the
        # redirect response body instead of the page.
        self.client = httpx.Client(follow_redirects=True)
        self.regex = re.compile(r"(https?://[^\s]+)")
        # URL found by the most recent successful ``match``; kept for callers
        # that inspect it, but ``process`` derives the URL from its own input.
        self.url: Optional[str] = None
        self.template = "{url}\n\n{content}"

    def _extract_url(self, content: str) -> Optional[str]:
        """Return the URL in ``content`` if it contains exactly one, else None."""
        urls = self.regex.findall(content)
        return urls[0] if len(urls) == 1 else None

    def match(self, input: JobInput) -> bool:
        """Return True if ``input.content`` contains exactly one http(s) URL."""
        url = self._extract_url(input.content)
        if url is None:
            return False
        self.url = url
        return True

    def process(self, input: JobInput) -> str:
        """Get content of website and return it as string.

        The URL is re-extracted from ``input`` rather than taken from state left
        by ``match``, so a shared instance cannot fetch another job's URL.

        Returns:
            The URL, a blank line, then the response body text.

        Raises:
            ValueError: If ``input.content`` does not contain exactly one URL.
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        url = self._extract_url(input.content)
        if url is None:
            raise ValueError("Input must contain exactly one http(s) URL")
        response = self.client.get(url)
        response.raise_for_status()
        text = response.text
        assert isinstance(text, str)
        return self.template.format(url=url, content=text)