import abc

from huggingface_hub import login
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, HfAgent
from transformers.tools import TextSummarizationTool

from gistillery.config import get_config

# module-level cache so that login and agent construction happen only once
agent = None


def get_agent() -> HfAgent:
    """Return a lazily initialized, process-wide HfAgent instance."""
    global agent
    if agent is None:
        login(get_config().hf_hub_token)
        agent = HfAgent(get_config().hf_agent)
    return agent
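

# Illustrative sketch (hypothetical helper, not used elsewhere): ``HfAgent.run``
# turns a natural-language instruction into generated tool calls; extra keyword
# arguments are exposed as variables to the generated code.
def _example_agent_run(text: str) -> str:
    result = get_agent().run("Summarize the following `text`.", text=text)
    return str(result)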


class Summarizer(abc.ABC):
    """Turns a text into a (shorter) summary."""

    @abc.abstractmethod
    def get_name(self) -> str:
        raise NotImplementedError

    @abc.abstractmethod
    def __call__(self, x: str) -> str:
        raise NotImplementedError


class HfDefaultSummarizer(Summarizer):
    """Summarizer backed by the default transformers text summarization tool."""

    def __init__(self) -> None:
        self.summarizer = TextSummarizationTool()

    def get_name(self) -> str:
        return "hf_default"

    def __call__(self, x: str) -> str:
        summary = self.summarizer(x)
        # the tool should always produce text for text input
        assert isinstance(summary, str)
        return summary
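

# Usage sketch (hypothetical input; the tool downloads its default
# summarization checkpoint on first call):
#
#     summarizer = HfDefaultSummarizer()
#     summarizer.get_name()  # "hf_default"
#     summary = summarizer("A long article about transformer models ...")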


class Tagger(abc.ABC):
    """Creates a list of tags for a text."""

    @abc.abstractmethod
    def get_name(self) -> str:
        raise NotImplementedError

    @abc.abstractmethod
    def __call__(self, x: str) -> list[str]:
        raise NotImplementedError


class HfDefaultTagger(Tagger):
    """Tagger that prompts an instruction-tuned seq2seq model for tags."""

    def __init__(self, model_name: str = "google/flan-t5-large") -> None:
        self.model_name = model_name
        config = GenerationConfig.from_pretrained(self.model_name)
        config.max_new_tokens = 50
        config.min_new_tokens = 25
        # sample with an increased temperature to make the model more creative;
        # without do_sample=True, generation is greedy and temperature is ignored
        config.do_sample = True
        config.temperature = 1.5
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.generation_config = config
        self.template = (
            "Create a list of tags for the text below. The tags should be high level "
            "and specific. Return the results as a comma separated list.\n\n"
            "{}\n\nTags:\n"
        )
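
    # For input "some article text", the rendered prompt reads:
    #
    #   Create a list of tags for the text below. The tags should be high level
    #   and specific. Return the results as a comma separated list.
    #
    #   some article text
    #
    #   Tags: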

    def _extract_tags(self, text: str) -> list[str]:
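        """Normalize the raw model output into a sorted list of unique tags.

        A sketch of the expected behavior (given an initialized ``tagger``):

        >>> tagger._extract_tags("Machine Learning, AI, machine learning")
        ['#ai', '#general', '#machinelearning']
        """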
        # normalize the comma separated model output into unique "#tag" entries
        tags = {"#general"}
        for tag in text.split(","):
            tag = tag.strip().lower().replace(" ", "")
            if not tag:
                # skip empty pieces, e.g. from a trailing comma
                continue
            if not tag.startswith("#"):
                tag = "#" + tag
            tags.add(tag)
        return sorted(tags)

    def __call__(self, x: str) -> list[str]:
        # returns a sorted list of unique tags, always including "#general"
        text = self.template.format(x)
        inputs = self.tokenizer(text, return_tensors="pt")
        outputs = self.model.generate(
            **inputs, generation_config=self.generation_config
        )
        output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        tags = self._extract_tags(output)
        return tags

    def get_name(self) -> str:
        return f"{self.__class__.__name__}({self.model_name})"
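

# Usage sketch for the tagger (hypothetical text; the first call downloads the
# flan-t5-large weights, which are several gigabytes):
#
#     tagger = HfDefaultTagger()
#     tagger.get_name()  # "HfDefaultTagger(google/flan-t5-large)"
#     tags = tagger("An introduction to convolutional neural networks ...")
#     # e.g. ["#deeplearning", "#general", ...]; exact tags vary because of sampling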