"""Summarization and tagging backends built on Hugging Face transformers."""

import abc

from huggingface_hub import login
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
from transformers import HfAgent
from transformers.tools import TextSummarizationTool

from gistillery.config import get_config

# Lazily-created, module-wide agent instance (see get_agent).
agent = None


def get_agent() -> HfAgent:
    """Return the shared HfAgent, creating it on first use.

    On the first call this logs into the Hugging Face Hub with the token from
    the project config and builds an ``HfAgent`` from ``get_config().hf_agent``.
    Later calls return the same cached instance.
    """
    global agent
    if agent is None:
        login(get_config().hf_hub_token)
        agent = HfAgent(get_config().hf_agent)
    return agent


class Summarizer(abc.ABC):
    """Interface for callables that turn a text into a summary string."""

    @abc.abstractmethod
    def get_name(self) -> str:
        """Return an identifier for this summarizer."""
        raise NotImplementedError

    @abc.abstractmethod
    def __call__(self, x: str) -> str:
        """Summarize ``x`` and return the summary."""
        raise NotImplementedError


class HfDefaultSummarizer(Summarizer):
    """Summarizer backed by transformers' ``TextSummarizationTool``."""

    def __init__(self) -> None:
        self.summarizer = TextSummarizationTool()

    def get_name(self) -> str:
        return "hf_default"

    def __call__(self, x: str) -> str:
        summary = self.summarizer(x)
        # Narrow the tool's loosely typed return value to str.
        assert isinstance(summary, str)
        return summary


class Tagger(abc.ABC):
    """Interface for callables that derive a list of tags from a text."""

    @abc.abstractmethod
    def get_name(self) -> str:
        """Return an identifier for this tagger."""
        raise NotImplementedError

    @abc.abstractmethod
    def __call__(self, x: str) -> list[str]:
        """Return the tags for ``x``."""
        raise NotImplementedError


class HfDefaultTagger(Tagger):
    """Tagger that prompts a seq2seq model for comma-separated tags.

    Args:
        model_name: Hub id of the seq2seq model (and tokenizer) to load.
    """

    def __init__(self, model_name: str = "google/flan-t5-large") -> None:
        self.model_name = model_name

        config = GenerationConfig.from_pretrained(self.model_name)
        config.max_new_tokens = 50
        config.min_new_tokens = 25
        # increase the temperature to make the model more creative
        # NOTE(review): temperature only influences generation when sampling
        # is enabled (do_sample=True); if the loaded config leaves sampling
        # off, this setting is likely ignored — confirm.
        config.temperature = 1.5

        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.generation_config = config
        # Prompt with a single "{}" slot for the input text.
        self.template = (
            "Create a list of tags for the text below. The tags should be high level "
            "and specific. Return the results as a comma separated list.\n\n"
            "{}\n\nTags:\n"
        )

    def _extract_tags(self, text: str) -> list[str]:
        """Parse comma-separated model output into sorted, unique hashtags.

        Each fragment is stripped, lower-cased, has inner spaces removed and
        gets a leading "#" if missing. "#general" is always included. Empty
        fragments (e.g. from trailing or doubled commas) are skipped.
        """
        tags = {"#general"}
        for fragment in text.split(","):
            tag = fragment.strip().lower().replace(" ", "")
            # Skip fragments with no content besides "#" characters; otherwise
            # a trailing comma would yield a bare "#" tag.
            if not tag.lstrip("#"):
                continue
            if not tag.startswith("#"):
                tag = "#" + tag
            tags.add(tag)
        return sorted(tags)

    def __call__(self, x: str) -> list[str]:
        """Generate tags for ``x``; the result contains no duplicates."""
        text = self.template.format(x)
        inputs = self.tokenizer(text, return_tensors="pt")
        outputs = self.model.generate(
            **inputs, generation_config=self.generation_config
        )
        # A single prompt was encoded, so only the first decoded sequence is used.
        output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return self._extract_tags(output)

    def get_name(self) -> str:
        return f"{self.__class__.__name__}({self.model_name})"