# gistillery/src/gistillery/registry.py
# Benjamin Bossan: "Use transformers agents where applicable" (commit 01ae0bb)
# raw | history | blame | 1.85 kB
from gistillery.base import JobInput
from gistillery.tools import Summarizer, Tagger, HfDefaultSummarizer, HfDefaultTagger
from gistillery.preprocessing import (
Processor,
RawTextProcessor,
ImageUrlProcessor,
DefaultUrlProcessor,
)
class RegistryError(RuntimeError, AssertionError):
    """Raised when a ToolRegistry is asked for a tool that was never registered.

    Also derives from AssertionError so that callers written against the
    earlier ``assert``-based checks (which raised AssertionError) keep working.
    """


class ToolRegistry:
    """Holds the processors, summarizer and tagger used to handle job inputs.

    Processors are consulted in list order by ``get_processor``; the first one
    whose ``match`` accepts the input is returned.
    """

    def __init__(self) -> None:
        self.processors: list[Processor] = []
        self.summarizer: Summarizer | None = None
        self.tagger: Tagger | None = None
        # Not read or written by any method of this class; presumably used by
        # external code — verify before removing.
        self.model = None
        self.tokenizer = None

    @property
    def summerizer(self) -> "Summarizer | None":
        """Misspelled alias of ``summarizer``, kept for backward compatibility."""
        return self.summarizer

    @summerizer.setter
    def summerizer(self, value: "Summarizer | None") -> None:
        self.summarizer = value

    def register_processor(self, processor: Processor, last: bool = True) -> None:
        """Add ``processor`` to the lookup list.

        Args:
            processor: The processor to register.
            last: If True (default), append it so it is tried after the
                already-registered processors; if False, put it first so it
                takes precedence over them.
        """
        if last:
            self.processors.append(processor)
        else:
            self.processors.insert(0, processor)

    def register_summarizer(self, summarizer: Summarizer) -> None:
        """Set (replacing any previous one) the summarizer returned by get_summarizer."""
        self.summarizer = summarizer

    def register_tagger(self, tagger: Tagger) -> None:
        """Set (replacing any previous one) the tagger returned by get_tagger."""
        self.tagger = tagger

    def get_processor(self, input: JobInput) -> Processor:
        """Return the first registered processor whose ``match`` accepts ``input``.

        Falls back to a new ``RawTextProcessor`` when no registered processor
        matches.

        Raises:
            RegistryError: if no processor has been registered at all.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not self.processors:
            raise RegistryError("No processors registered")
        for processor in self.processors:
            if processor.match(input):
                return processor
        return RawTextProcessor()

    def get_summarizer(self) -> Summarizer:
        """Return the registered summarizer.

        Raises:
            RegistryError: if no summarizer has been registered.
        """
        if self.summarizer is None:
            raise RegistryError("No summarizer registered")
        return self.summarizer

    def get_tagger(self) -> Tagger:
        """Return the registered tagger.

        Raises:
            RegistryError: if no tagger has been registered.
        """
        if self.tagger is None:
            raise RegistryError("No tagger registered")
        return self.tagger
# Process-wide registry, built lazily by get_tool_registry() on first use.
_registry: "ToolRegistry | None" = None


def get_tool_registry() -> ToolRegistry:
    """Return the shared ToolRegistry, building it on the first call.

    The first call instantiates ``HfDefaultSummarizer`` and ``HfDefaultTagger``
    and registers the processors in this order: ``ImageUrlProcessor``,
    ``DefaultUrlProcessor``, ``RawTextProcessor``. Later calls return the same
    instance.

    The registry is stored in the module global only after it is fully
    configured. If constructing any tool raises, no half-populated registry is
    cached, and the next call retries from scratch.
    """
    global _registry
    if _registry is not None:
        return _registry

    summarizer = HfDefaultSummarizer()
    tagger = HfDefaultTagger()
    registry = ToolRegistry()
    # Registration order matters: ToolRegistry.get_processor returns the first
    # processor whose match() accepts the input.
    registry.register_processor(ImageUrlProcessor())
    registry.register_processor(DefaultUrlProcessor())
    registry.register_processor(RawTextProcessor())
    registry.register_summarizer(summarizer)
    registry.register_tagger(tagger)

    # Publish only the fully configured registry.
    _registry = registry
    return registry