# gistillery/src/ml.py
import abc
import logging
import re

import httpx
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

from base import JobInput

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

MODEL_NAME = "google/flan-t5-large"

# The model and tokenizer are loaded once at import time and shared by Summarizer and Tagger.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


class Summarizer:
    def __init__(self) -> None:
        self.template = "Summarize the text below in two sentences:\n\n{}"
        self.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
        self.generation_config.max_new_tokens = 200
        self.generation_config.min_new_tokens = 100
        self.generation_config.top_k = 5
        self.generation_config.repetition_penalty = 1.5

    def __call__(self, x: str) -> str:
        text = self.template.format(x)
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model.generate(**inputs, generation_config=self.generation_config)
        output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        assert isinstance(output, str)
        return output

    def get_name(self) -> str:
        return f"Summarizer({MODEL_NAME})"


class Tagger:
    def __init__(self) -> None:
        self.template = (
            "Create a list of tags for the text below. The tags should be high level "
            "and specific. Prefix each tag with a hashtag.\n\n{}\n\nTags: #general"
        )
        self.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
        self.generation_config.max_new_tokens = 50
        self.generation_config.min_new_tokens = 25
        # increase the temperature to make the model more creative
        self.generation_config.temperature = 1.5

    def _extract_tags(self, text: str) -> list[str]:
        tags = set()
        for tag in text.split():
            if tag.startswith("#"):
                tags.add(tag.lower())
        return sorted(tags)

    def __call__(self, x: str) -> list[str]:
        text = self.template.format(x)
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model.generate(**inputs, generation_config=self.generation_config)
        output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        tags = self._extract_tags(output)
        return tags

    def get_name(self) -> str:
        return f"Tagger({MODEL_NAME})"


class Processor(abc.ABC):
    def __call__(self, job: JobInput) -> str:
        _id = job.id
        logger.info(f"Processing {job} with {self.__class__.__name__} (id={_id[:8]})")
        result = self.process(job)
        logger.info(f"Finished processing input (id={_id[:8]})")
        return result

    @abc.abstractmethod
    def process(self, input: JobInput) -> str:
        raise NotImplementedError

    @abc.abstractmethod
    def match(self, input: JobInput) -> bool:
        raise NotImplementedError

    @abc.abstractmethod
    def get_name(self) -> str:
        raise NotImplementedError


class RawProcessor(Processor):
    def match(self, input: JobInput) -> bool:
        return True

    def process(self, input: JobInput) -> str:
        return input.content

    def get_name(self) -> str:
        return self.__class__.__name__


class PlainUrlProcessor(Processor):
    def __init__(self) -> None:
        self.client = httpx.Client()
        self.regex = re.compile(r"(https?://[^\s]+)")
        self.url: str | None = None
        self.template = "{url}\n\n{content}"

    def match(self, input: JobInput) -> bool:
        urls = list(self.regex.findall(input.content))
        if len(urls) == 1:
            self.url = urls[0]
            return True
        return False

    def process(self, input: JobInput) -> str:
        """Fetch the website and return its content as a string, prefixed with the URL"""
        assert isinstance(self.url, str)
        text = self.client.get(self.url).text
        assert isinstance(text, str)
        text = self.template.format(url=self.url, content=text)
        return text

    def get_name(self) -> str:
        return self.__class__.__name__


class ProcessorRegistry:
    def __init__(self) -> None:
        self.registry: list[Processor] = []
        self.default_registry: list[Processor] = []
        self.set_default_processors()

    def set_default_processors(self) -> None:
        self.default_registry.extend([PlainUrlProcessor(), RawProcessor()])

    def register(self, processor: Processor) -> None:
        self.registry.append(processor)

    def dispatch(self, input: JobInput) -> Processor:
        for processor in self.registry + self.default_registry:
            if processor.match(input):
                return processor
        # should never be reached (RawProcessor matches everything), but fall back just in case
        return RawProcessor()
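

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module). JobInput is
    # assumed here to be constructible from the two attributes used above, `id`
    # and `content`; adjust if the definition in base.py differs.
    logging.basicConfig(level=logging.DEBUG)
    job = JobInput(id="0123456789abcdef", content="https://example.com")
    processor = ProcessorRegistry().dispatch(job)  # PlainUrlProcessor for a single URL
    page = processor(job)  # the URL followed by the fetched page body
    print(Summarizer()(page))
    print(Tagger()(page))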