# Spaces status: Runtime error (page header captured together with the source)
import abc | |
import logging | |
import re | |
import httpx | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig | |
from base import JobInput | |
# Checkpoint identifier used for the model, the tokenizer and the generation
# configs created further down in this module.
MODEL_NAME = "google/flan-t5-large"

# Module-level logger, set to DEBUG so all processing messages are emitted.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Model and tokenizer are loaded once at import time and shared as module
# globals by the classes below.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
class Summarizer:
    """Callable that turns a text into a short summary.

    Uses the module-level ``model`` and ``tokenizer`` together with a
    generation config derived from ``MODEL_NAME``.
    """

    def __init__(self) -> None:
        # Prompt wrapped around the text to summarize.
        self.template = "Summarize the text below in two sentences:\n\n{}"

        # Start from the checkpoint's own generation settings and override a
        # handful of them.
        # NOTE(review): top_k presumably only matters for sampling-based
        # decoding, which is not enabled here -- confirm.
        config = GenerationConfig.from_pretrained(MODEL_NAME)
        overrides = {
            "max_new_tokens": 200,
            "min_new_tokens": 100,
            "top_k": 5,
            "repetition_penalty": 1.5,
        }
        for option, value in overrides.items():
            setattr(config, option, value)
        self.generation_config = config

    def __call__(self, x: str) -> str:
        """Return the generated summary for ``x``."""
        prompt = self.template.format(x)
        encoded = tokenizer(prompt, return_tensors="pt")
        generated = model.generate(
            **encoded, generation_config=self.generation_config
        )
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
        # Only one prompt is passed in, so only the first decoded entry is used.
        summary = decoded[0]
        assert isinstance(summary, str)
        return summary

    def get_name(self) -> str:
        """Return a label containing the class role and the checkpoint name."""
        return f"Summarizer({MODEL_NAME})"
class Tagger:
    """Callable that produces a sorted list of hashtag-style tags for a text.

    Uses the module-level ``model`` and ``tokenizer`` together with a
    generation config derived from ``MODEL_NAME``.
    """

    # Punctuation that may be glued to the end of a generated tag
    # (e.g. "#ai," or "#ai.") and must not make it a different tag.
    _TRAILING_PUNCTUATION = ".,;:!?"

    def __init__(self) -> None:
        # The prompt ends with a first tag to prime the expected output format.
        self.template = (
            "Create a list of tags for the text below. The tags should be high level "
            "and specific. Prefix each tag with a hashtag.\n\n{}\n\nTags: #general"
        )
        self.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
        self.generation_config.max_new_tokens = 50
        self.generation_config.min_new_tokens = 25
        # increase the temperature to make the model more creative
        # NOTE(review): temperature is presumably only honoured by
        # sampling-based decoding and do_sample is not enabled here, so this
        # setting may currently have no effect -- confirm, and enable sampling
        # deliberately if more varied tags are wanted.
        self.generation_config.temperature = 1.5

    def _extract_tags(self, text: str) -> list[str]:
        """Return the unique, lower-cased hashtags found in ``text``, sorted.

        A tag is a whitespace-separated token starting with ``#``. Trailing
        punctuation is stripped so ``#ai,`` and ``#ai`` collapse into one tag,
        and tokens consisting only of ``#`` characters are ignored.
        """
        tags = set()
        for token in text.split():
            if not token.startswith("#"):
                continue
            tag = token.rstrip(self._TRAILING_PUNCTUATION).lower()
            # Skip a bare "#" (or "##"): there is no tag text after the marker.
            if tag.lstrip("#"):
                tags.add(tag)
        return sorted(tags)

    def __call__(self, x: str) -> list[str]:
        """Generate tags for ``x`` and return them as a sorted list."""
        text = self.template.format(x)
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model.generate(**inputs, generation_config=self.generation_config)
        # Only one prompt is passed in, so only the first decoded entry is used.
        output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return self._extract_tags(output)

    def get_name(self) -> str:
        """Return a label containing the class role and the checkpoint name."""
        return f"Tagger({MODEL_NAME})"
class Processor(abc.ABC):
    """Base class for job processors.

    Subclasses override ``match`` (can this processor handle the job?),
    ``process`` (turn the job into text) and ``get_name``. Calling an
    instance runs ``process`` with log messages before and after.
    """

    def __call__(self, job: JobInput) -> str:
        """Process ``job`` and return the result, logging start and finish."""
        # Only the first 8 characters of the id are logged to keep lines short.
        short_id = job.id[:8]
        # Log the job itself: the bare name ``input`` here would refer to the
        # builtin function, not to the job being processed.
        logger.info(
            "Processing %s with %s (id=%s)",
            job,
            self.__class__.__name__,
            short_id,
        )
        result = self.process(job)
        logger.info("Finished processing input (id=%s)", short_id)
        return result

    def process(self, input: JobInput) -> str:
        """Return the processed text for ``input``; must be overridden."""
        raise NotImplementedError

    def match(self, input: JobInput) -> bool:
        """Return whether this processor handles ``input``; must be overridden."""
        raise NotImplementedError

    def get_name(self) -> str:
        """Return a display name for this processor; must be overridden."""
        raise NotImplementedError
class RawProcessor(Processor):
    """Catch-all processor: accepts every job and returns its content as is."""

    def get_name(self) -> str:
        """Return the class name as the processor's display name."""
        return type(self).__name__

    def match(self, input: JobInput) -> bool:
        """Accept any job unconditionally."""
        return True

    def process(self, input: JobInput) -> str:
        """Return the job's content unchanged."""
        return input.content
class PlainUrlProcessor(Processor):
    """Fetch the page behind the single URL contained in a job.

    A job matches when its content contains exactly one ``http``/``https``
    URL. The result is the URL followed by the fetched response text.
    """

    def __init__(self) -> None:
        self.client = httpx.Client()
        self.regex = re.compile(r"(https?://[^\s]+)")
        # Most recently seen URL. Kept for backward compatibility only;
        # process() no longer depends on it.
        self.url = None
        self.template = "{url}\n\n{content}"

    def _single_url(self, input: JobInput) -> "str | None":
        """Return the URL in ``input.content`` if there is exactly one, else None."""
        urls = self.regex.findall(input.content)
        return urls[0] if len(urls) == 1 else None

    def match(self, input: JobInput) -> bool:
        """Return True when the job's content contains exactly one URL."""
        url = self._single_url(input)
        if url is None:
            return False
        self.url = url
        return True

    def process(self, input: JobInput) -> str:
        """Get content of website and return it as string.

        The URL is taken from ``input`` itself rather than from state left
        behind by an earlier ``match`` call, so a shared instance can never
        fetch the URL of a different job.

        Raises:
            ValueError: if ``input.content`` does not contain exactly one URL.
        """
        url = self._single_url(input)
        if url is None:
            raise ValueError(
                f"{self.__class__.__name__} requires exactly one URL in the input"
            )
        self.url = url
        # HTTPX does not follow redirects unless asked to; without this a
        # redirecting URL would produce an empty body.
        text = self.client.get(url, follow_redirects=True).text
        return self.template.format(url=url, content=text)

    def get_name(self) -> str:
        """Return the class name as the processor's display name."""
        return self.__class__.__name__
class ProcessorRegistry:
    """Ordered collection of processors with first-match dispatch.

    Explicitly registered processors are consulted before the built-in
    defaults.
    """

    def __init__(self) -> None:
        self.registry: list[Processor] = []
        self.default_registry: list[Processor] = []
        self.set_default_processors()

    def set_default_processors(self) -> None:
        """Append the built-in processors to the default list."""
        # Order matters: the URL processor is tried before the catch-all.
        for builtin in (PlainUrlProcessor(), RawProcessor()):
            self.default_registry.append(builtin)

    def register(self, processor: Processor) -> None:
        """Add ``processor`` after the previously registered ones."""
        self.registry.append(processor)

    def dispatch(self, input: JobInput) -> Processor:
        """Return the first processor whose ``match`` accepts ``input``."""
        candidates = [*self.registry, *self.default_registry]
        chosen = next(
            (candidate for candidate in candidates if candidate.match(input)),
            None,
        )
        if chosen is not None:
            return chosen
        # Should never be required, since the default RawProcessor matches
        # everything; kept as a safety net.
        return RawProcessor()