# Hugging Face Spaces page header (scrape artifact, preserved as comments):
# Spaces:
# Runtime error
import abc

from huggingface_hub import login
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    GenerationConfig,
    HfAgent,
)
from transformers.tools import TextSummarizationTool

from gistillery.config import get_config
# Process-wide cached agent; created lazily by get_agent().
agent = None


def get_agent() -> HfAgent:
    """Return the shared HfAgent, creating and caching it on first use.

    On the first call this logs in to the Hugging Face Hub using the
    configured token, then instantiates the agent from the configured
    endpoint. Later calls return the same cached instance.
    """
    global agent
    if agent is not None:
        return agent
    login(get_config().hf_hub_token)
    agent = HfAgent(get_config().hf_agent)
    return agent
class Summarizer(abc.ABC):
    """Abstract interface for text summarizers.

    Subclasses must provide a short identifying name and a callable that
    maps an input text to its summary. Marking the methods with
    ``@abc.abstractmethod`` makes the ABC enforce the interface at
    instantiation time instead of failing only when the missing method
    is first called.
    """

    @abc.abstractmethod
    def get_name(self) -> str:
        """Return a short identifier for this summarizer implementation."""
        raise NotImplementedError

    @abc.abstractmethod
    def __call__(self, x: str) -> str:
        """Summarize the text ``x`` and return the summary."""
        raise NotImplementedError
class HfDefaultSummarizer(Summarizer):
    """Summarizer backed by the stock Hugging Face text-summarization tool."""

    def __init__(self) -> None:
        # The tool downloads/loads its default model on construction.
        self.summarizer = TextSummarizationTool()

    def get_name(self) -> str:
        return "hf_default"

    def __call__(self, x: str) -> str:
        result = self.summarizer(x)
        # The tool is expected to hand back plain text; guard against
        # other return shapes before exposing it to callers.
        assert isinstance(result, str)
        return result
class Tagger(abc.ABC):
    """Abstract interface for taggers that derive a list of tags from text.

    Subclasses must implement both methods; ``@abc.abstractmethod`` makes
    the ABC reject instantiation of incomplete subclasses instead of
    deferring the failure to the first call of a missing method.
    """

    @abc.abstractmethod
    def get_name(self) -> str:
        """Return a short identifier for this tagger implementation."""
        raise NotImplementedError

    @abc.abstractmethod
    def __call__(self, x: str) -> list[str]:
        """Return the list of tags extracted from the text ``x``."""
        raise NotImplementedError
class HfDefaultTagger(Tagger):
    """Tagger that prompts a seq2seq model for a comma separated tag list.

    The model is asked to emit comma separated tags; the raw output is
    normalized into a sorted, duplicate-free list of ``#``-prefixed tags
    that always contains ``#general``.
    """

    def __init__(self, model_name: str = "google/flan-t5-large") -> None:
        self.model_name = model_name
        config = GenerationConfig.from_pretrained(self.model_name)
        config.max_new_tokens = 50
        config.min_new_tokens = 25
        # increase the temperature to make the model more creative
        # NOTE(review): temperature only affects generation when sampling
        # is enabled (do_sample=True); with the default greedy decoding it
        # is ignored — confirm the intended decoding mode.
        config.temperature = 1.5
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.generation_config = config
        self.template = (
            "Create a list of tags for the text below. The tags should be high level "
            "and specific. Return the results as a comma separated list.\n\n"
            "{}\n\nTags:\n"
        )

    def _extract_tags(self, text: str) -> list[str]:
        """Parse raw model output into a sorted, de-duplicated tag list.

        Each comma separated segment is lowercased, stripped of spaces,
        and prefixed with ``#`` when missing. Empty segments (e.g. from a
        trailing comma or ",,") are skipped — previously they produced a
        bare "#" tag. ``#general`` is always included.
        """
        tags = {"#general"}
        for raw in text.split(","):
            tag = raw.strip().lower().replace(" ", "")
            if not tag:
                # skip empty segments so "a,,b" or "a," never yields "#"
                continue
            if not tag.startswith("#"):
                tag = "#" + tag
            tags.add(tag)
        return sorted(tags)

    def __call__(self, x: str) -> list[str]:
        """Generate tags for ``x``; the result is sorted and duplicate-free."""
        prompt = self.template.format(x)
        inputs = self.tokenizer(prompt, return_tensors="pt")
        outputs = self.model.generate(
            **inputs, generation_config=self.generation_config
        )
        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return self._extract_tags(decoded)

    def get_name(self) -> str:
        return f"{self.__class__.__name__}({self.model_name})"