Spaces:
Runtime error
Runtime error
File size: 4,690 Bytes
a240da9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import abc
import logging
import re
import httpx
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
from base import JobInput
# Module-level logger; its threshold is lowered to DEBUG for this module.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# Hugging Face model identifier used by both Summarizer and Tagger below.
MODEL_NAME = "google/flan-t5-large"
# The model and tokenizer are loaded once, at import time, and shared as
# module globals by every Summarizer/Tagger instance.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
class Summarizer:
    """Callable that prompts the shared seq2seq model for a short summary.

    The prompt asks for a two-sentence summary; generation is bounded to
    between 100 and 200 new tokens.
    """

    def __init__(self) -> None:
        """Set up the prompt template and the generation settings."""
        self.template = "Summarize the text below in two sentences:\n\n{}"

        config = GenerationConfig.from_pretrained(MODEL_NAME)
        config.max_new_tokens = 200
        config.min_new_tokens = 100
        # NOTE(review): top_k is a sampling parameter; confirm that sampling
        # is enabled by the pretrained config, otherwise it has no effect.
        config.top_k = 5
        config.repetition_penalty = 1.5
        self.generation_config = config

    def __call__(self, x: str) -> str:
        """Return the model's summary of ``x``.

        Args:
            x: The text to summarize; it is inserted into the prompt template.

        Returns:
            The first decoded sequence with special tokens removed.
        """
        prompt = self.template.format(x)
        encoded = tokenizer(prompt, return_tensors="pt")
        generated = model.generate(
            **encoded, generation_config=self.generation_config
        )
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
        summary = decoded[0]
        assert isinstance(summary, str)
        return summary

    def get_name(self) -> str:
        """Return a display name that includes the model identifier."""
        return "Summarizer({})".format(MODEL_NAME)
class Tagger:
    """Callable that prompts the shared seq2seq model for hashtags.

    The model output is split on whitespace and every token that starts
    with ``#`` is kept as a tag (lower-cased, de-duplicated, sorted).
    """

    def __init__(self) -> None:
        """Set up the prompt template and the generation settings."""
        self.template = (
            "Create a list of tags for the text below. The tags should be high level "
            "and specific. Prefix each tag with a hashtag.\n\n{}\n\nTags: #general"
        )
        self.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
        self.generation_config.max_new_tokens = 50
        self.generation_config.min_new_tokens = 25
        # Increase the temperature to make the model more creative.
        # Temperature only influences sampling-based decoding, so sampling is
        # switched on explicitly; with greedy decoding the value is ignored.
        self.generation_config.do_sample = True
        self.generation_config.temperature = 1.5

    def _extract_tags(self, text: str) -> list[str]:
        """Return the sorted, de-duplicated, lower-cased hashtags in ``text``.

        Trailing punctuation (``.,;:!?``) is trimmed so that ``#ai,`` and
        ``#ai`` count as the same tag, and tokens that consist only of ``#``
        characters are dropped.
        """
        tags = set()
        for token in text.split():
            if not token.startswith("#"):
                continue
            tag = token.rstrip(".,;:!?").lower()
            # Skip tokens such as "#" or "#," that carry no tag text.
            if tag.strip("#"):
                tags.add(tag)
        return sorted(tags)

    def __call__(self, x: str) -> list[str]:
        """Generate tags for ``x``.

        Args:
            x: The text to tag; it is inserted into the prompt template.

        Returns:
            The hashtags extracted from the first decoded sequence.
        """
        text = self.template.format(x)
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model.generate(**inputs, generation_config=self.generation_config)
        output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return self._extract_tags(output)

    def get_name(self) -> str:
        """Return a display name that includes the model identifier."""
        return f"Tagger({MODEL_NAME})"
class Processor(abc.ABC):
    """Base class for job processors.

    Calling an instance logs the start and end of the work and delegates to
    ``process``. Subclasses are expected to override ``process``, ``match``
    and ``get_name``; the base implementations raise ``NotImplementedError``.
    """

    def __call__(self, job: JobInput) -> str:
        """Process ``job`` with logging around the call.

        Args:
            job: The job to process; ``job.id`` is sliced to its first eight
                characters for the log messages.

        Returns:
            Whatever ``self.process(job)`` returns.
        """
        _id = job.id
        # Log the job explicitly: a bare ``input`` here would refer to the
        # builtin function, not to the job being processed.
        logger.info(
            "Processing %s with %s (id=%s)", job, self.__class__.__name__, _id[:8]
        )
        result = self.process(job)
        logger.info("Finished processing input (id=%s)", _id[:8])
        return result

    def process(self, input: JobInput) -> str:
        """Turn ``input`` into text; must be overridden by subclasses."""
        raise NotImplementedError

    def match(self, input: JobInput) -> bool:
        """Tell whether this processor handles ``input``; must be overridden."""
        raise NotImplementedError

    def get_name(self) -> str:
        """Return a display name for the processor; must be overridden."""
        raise NotImplementedError
class RawProcessor(Processor):
    """Pass-through processor: accepts every input and returns its content."""

    def match(self, input: JobInput) -> bool:
        """Accept any input unconditionally."""
        return True

    def process(self, input: JobInput) -> str:
        """Return ``input.content`` unchanged."""
        content = input.content
        return content

    def get_name(self) -> str:
        """Return the class name of this processor."""
        return type(self).__name__
class PlainUrlProcessor(Processor):
    """Processor for inputs whose content contains exactly one http(s) URL.

    The URL is fetched and the response text is returned, prefixed by the URL.
    """

    def __init__(self) -> None:
        """Create the HTTP client, the URL pattern and the output template."""
        self.client = httpx.Client()
        self.regex = re.compile(r"(https?://[^\s]+)")
        # URL found by the most recent successful ``match`` call (None until then).
        self.url = None
        self.template = "{url}\n\n{content}"

    def _find_urls(self, input: JobInput) -> list[str]:
        """Return every http(s) URL found in ``input.content``."""
        return list(self.regex.findall(input.content))

    def match(self, input: JobInput) -> bool:
        """Return True when ``input.content`` contains exactly one URL.

        On success the URL is also remembered in ``self.url``.
        """
        urls = self._find_urls(input)
        if len(urls) != 1:
            return False
        self.url = urls[0]
        return True

    def process(self, input: JobInput) -> str:
        """Get content of website and return it as string.

        The URL is taken from ``input`` itself, so the result does not depend
        on which input ``match`` was last called with on this shared
        instance. Only when ``input`` does not contain exactly one URL is the
        URL remembered by the last successful ``match`` used instead.

        Raises:
            ValueError: If no URL can be determined.
        """
        urls = self._find_urls(input)
        url = urls[0] if len(urls) == 1 else self.url
        if not isinstance(url, str):
            raise ValueError("PlainUrlProcessor.process called without a URL to fetch")
        body = self.client.get(url).text
        return self.template.format(url=url, content=body)

    def get_name(self) -> str:
        """Return the class name of this processor."""
        return self.__class__.__name__
class ProcessorRegistry:
    """Ordered collection of processors used to pick a handler for a job.

    Explicitly registered processors are consulted before the defaults.
    """

    def __init__(self) -> None:
        """Start with an empty custom registry and the default processors."""
        self.registry: list[Processor] = []
        self.default_registry: list[Processor] = []
        self.set_default_processors()

    def set_default_processors(self) -> None:
        """Append the built-in processors to the default registry."""
        for builtin in (PlainUrlProcessor(), RawProcessor()):
            self.default_registry.append(builtin)

    def register(self, processor: Processor) -> None:
        """Add ``processor`` after the already registered custom processors."""
        self.registry.append(processor)

    def dispatch(self, input: JobInput) -> Processor:
        """Return the first processor whose ``match`` accepts ``input``.

        Custom processors are tried in registration order, then the defaults.
        """
        candidates = [*self.registry, *self.default_registry]
        chosen = next(
            (candidate for candidate in candidates if candidate.match(input)),
            None,
        )
        if chosen is None:
            # Should never be required, since the default RawProcessor
            # matches everything; kept as a safety net.
            return RawProcessor()
        return chosen
|