Spaces:
Runtime error
Runtime error
import abc | |
import io | |
import logging | |
import re | |
from typing import Optional | |
import torch | |
import trafilatura | |
import urllib3 | |
from httpx import Client | |
from PIL import Image | |
from transformers import AutoProcessor, WhisperForConditionalGeneration | |
from gistillery.base import JobInput | |
from gistillery.config import get_config | |
from gistillery.errors import ProcessingError | |
from gistillery.media import download_yt_audio, load_audio | |
from gistillery.tools import get_agent | |
logger = logging.getLogger(__name__) | |
logger.setLevel(logging.DEBUG) | |
RE_URL = re.compile(r"(https?://[^\s]+)") | |
def get_url(text: str) -> str | None: | |
urls: list[str] = list(RE_URL.findall(text)) | |
if len(urls) == 1: | |
url = urls[0] | |
return url | |
return None | |
class Processor(abc.ABC):
    """Base class for turning a submitted job into plain text.

    Subclasses implement ``match`` (can this processor handle the job?) and
    ``process`` (produce the text). Calling an instance runs ``process`` and
    truncates the result to ``processing_max_length`` from the config.
    Subclasses that override ``__init__`` must call ``super().__init__()``.
    """

    # Class-level default so ``__call__`` can detect a subclass that skipped
    # ``super().__init__()``. Without it, the missing instance attribute
    # raised AttributeError before the intended RuntimeError could fire.
    _super_init_called: bool = False

    def __init__(self) -> None:
        self.max_length = get_config().processing_max_length
        self._super_init_called = True

    def get_name(self) -> str:
        """Return the name of the concrete processor class."""
        return self.__class__.__name__

    def __call__(self, job: JobInput) -> str:
        """Run ``process`` on ``job`` and return the text, capped at ``max_length``.

        Raises:
            RuntimeError: if the subclass did not call ``super().__init__()``.
        """
        if not self._super_init_called:
            raise RuntimeError(
                "super().__init__() was not called with class "
                f"{self.__class__.__name__}"
            )

        _id = job.id  # NOTE(review): assumed to be a str (it is sliced below)
        # The message previously interpolated the builtin ``input`` function
        # by mistake; log the processor and job id instead.
        logger.info(
            f"Processing input with {self.__class__.__name__} (id={_id[:8]})"
        )
        result = self.process(job)
        if len(result) > self.max_length:
            logger.warning(
                f"Length of result ({len(result)}) exceeds max_length "
                f"({self.max_length}), truncating"
            )
            result = result[: self.max_length]

        logger.info(f"Finished processing input (id={_id[:8]})")
        return result

    def process(self, input: JobInput) -> str:
        """Return the text for ``input``; must be implemented by subclasses."""
        raise NotImplementedError

    def match(self, input: JobInput) -> bool:
        """Return True if this processor handles ``input``; implemented by subclasses."""
        raise NotImplementedError
class RawTextProcessor(Processor):
    """Processor that accepts any job and returns its text as-is.

    The only transformation is removal of leading and trailing whitespace.
    """

    def process(self, input: JobInput) -> str:
        text = input.content
        return text.strip()

    def match(self, input: JobInput) -> bool:
        # Unconditional: every job is accepted by this processor.
        return True
class DefaultUrlProcessor(Processor):
    """Fetch the single URL in a job and extract the page's main text.

    ``match`` stores the URL it found on ``self.url``. ``process`` then
    downloads that URL, so ``match`` must have returned True first.
    """

    def __init__(self) -> None:
        super().__init__()
        self.client = Client()
        # Set by ``match``. This used to be assigned the typing object
        # ``Optional[str]`` itself instead of being annotated with it.
        self.url: Optional[str] = None
        self.template = "{url}\n\n{content}"

    def match(self, input: JobInput) -> bool:
        """Return True, and remember the URL, if the job contains exactly one URL."""
        url = get_url(input.content.strip())
        if url is None:
            return False
        self.url = url
        return True

    def process(self, input: JobInput) -> str:
        """Get content of website and return it as string

        Returns:
            The URL, a blank line, then the extracted main text.

        Raises:
            TypeError: if no URL has been stored by ``match``.
            ProcessingError: if no text could be extracted from the page.
        """
        if not isinstance(self.url, str):
            raise TypeError("self.url must be a string")

        # httpx does not follow redirects by default. Without this, the
        # (usually empty) body of a 3xx response would be processed.
        html = self.client.get(self.url, follow_redirects=True).text
        extracted = trafilatura.extract(html)
        if not extracted:
            # trafilatura returns None when extraction fails. The template
            # would otherwise render that as the literal string "None".
            raise ProcessingError("No text could be extracted from website")
        return self.template.format(url=self.url, content=extracted)
class PdfUrlProcessor(Processor):
    """Download a PDF linked in the job and extract its text with pypdf."""

    def __init__(self) -> None:
        super().__init__()
        self.client = Client()
        # Set by ``match``; ``process`` requires it.
        self.url: Optional[str] = None
        self.template = "{url}\n\n{content}"
        # Page reading stops after the first page containing any of these
        # strings (from the ``pdf_stop_words`` config value).
        self.stop_words = get_config().pdf_stop_words

    def match(self, input: JobInput) -> bool:
        """Return True, and remember the URL, if the job's single URL ends in .pdf."""
        url = get_url(input.content.strip())
        if url is None:
            return False
        # Drop the fragment and query string so that e.g. ".../paper.pdf?dl=1"
        # is still recognized by the ".pdf" suffix of its path.
        path = url.split("#", 1)[0].split("?", 1)[0]
        suffix = path.rsplit(".", 1)[-1].lower()
        if suffix != "pdf":
            return False
        self.url = url
        return True

    def process(self, input: JobInput) -> str:
        """Return the URL followed by the text of the PDF's pages.

        Pages are read in order. Reading stops after (and including) the
        first page whose text contains any configured stop word as a substring.

        Raises:
            TypeError: if no URL has been stored by ``match``.
            ProcessingError: if no text could be extracted from the PDF.
        """
        if not isinstance(self.url, str):
            raise TypeError("self.url must be a string")
        # httpx does not follow redirects by default.
        response = self.client.get(self.url, follow_redirects=True)

        import pypdf  # local import: only this processor needs pypdf

        pdf = pypdf.PdfReader(io.BytesIO(response.content))
        results = []
        for page in pdf.pages:
            page_text = page.extract_text()
            results.append(page_text)
            if any(word in page_text for word in self.stop_words):
                break

        text = "\n".join(results).strip()
        if not text:
            raise ProcessingError("No text could be extracted from PDF")
        return self.template.format(url=self.url, content=text)
class ImageUrlProcessor(Processor):
    """Download an image linked in the job and caption it with the agent."""

    def __init__(self) -> None:
        super().__init__()
        self.client = Client()
        # Set by ``match``; ``process`` requires it.
        self.url: Optional[str] = None
        self.template = "{url}\n\n{content}"
        # URL path suffixes (lowercase) treated as images.
        self.image_suffixes = {'jpg', 'jpeg', 'png', 'gif'}

    def match(self, input: JobInput) -> bool:
        """Return True, and remember the URL, if the job's single URL has an image suffix."""
        url = get_url(input.content.strip())
        if url is None:
            return False
        # Drop the fragment and query string so that ".../cat.png?w=200" is
        # still recognized by the suffix of its path.
        path = url.split("#", 1)[0].split("?", 1)[0]
        suffix = path.rsplit(".", 1)[-1].lower()
        if suffix not in self.image_suffixes:
            return False
        self.url = url
        return True

    def process(self, input: JobInput) -> str:
        """Return the URL followed by the agent's caption for the image.

        Raises:
            TypeError: if no URL has been stored by ``match``.
        """
        if not isinstance(self.url, str):
            raise TypeError("self.url must be a string")
        # httpx does not follow redirects by default.
        response = self.client.get(self.url, follow_redirects=True)
        # Normalize to RGB, presumably so the captioning tool always receives
        # 3-channel input (e.g. for palette or alpha images).
        image = Image.open(io.BytesIO(response.content)).convert('RGB')
        caption = get_agent().run("Caption the following image", image=image)
        return self.template.format(url=self.url, content=str(caption))
class YoutubeUrlProcessor(Processor):
    """Download yt audio, transcribe with whisper"""

    def __init__(self) -> None:
        super().__init__()
        self.client = Client()
        # Set by ``match``; ``process`` requires it.
        self.url: Optional[str] = None
        self.template = "{url}\n\n{content}"
        self.processor = AutoProcessor.from_pretrained("openai/whisper-small.en")
        self.model = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-small.en"
        )
        # Hosts accepted as YouTube links.
        self.hosts = {"www.youtube.com", "youtube.com", "youtu.be"}

    def match(self, input: JobInput) -> bool:
        """Return True, and remember the URL, if the job's single URL is on a YouTube host."""
        url = get_url(input.content.strip())
        if url is None:
            return False
        parsed = urllib3.util.parse_url(url)
        if parsed.host not in self.hosts:
            return False
        self.url = url
        return True

    # Must be a staticmethod: ``process`` calls it via ``self``. As a plain
    # function, ``self`` was bound to ``input_ids`` and the call raised
    # TypeError (multiple values for ``max_len``).
    @staticmethod
    def make_batch(input_ids: torch.Tensor, max_len: int) -> torch.Tensor:
        """Create batches from last dimension, pad last batch if necessary

        The last dimension is cut into chunks of ``max_len``. The final chunk
        is right-padded with zeros up to ``max_len``, and all chunks are
        concatenated along the first dimension.

        Examples
        >>> import torch
        >>> x = torch.zeros((1, 10, 213))
        >>> YoutubeUrlProcessor.make_batch(x, max_len=100).shape
        torch.Size([3, 10, 100])
        >>> y = torch.zeros((1, 10, 200))
        >>> YoutubeUrlProcessor.make_batch(y, max_len=100).shape
        torch.Size([2, 10, 100])
        """
        # ugly workaround, transformers whisper implementation requires a
        # specific shape of input length, probably there is a better way...
        batches = input_ids.split(max_len, dim=-1)  # type: ignore
        last = batches[-1]
        n = last.shape[-1]
        # Pad on the right only. The old (1, max_len - n - 1) shifted the last
        # chunk by one frame. When n == max_len it also padded by -1, which
        # cropped the chunk's final frame.
        last = torch.nn.functional.pad(last, (0, max_len - n), value=0.0)
        batches = batches[:-1] + (last,)
        return torch.concat(batches)

    def process(self, input: JobInput) -> str:
        """Download the stored YouTube URL's audio and return its transcription.

        Raises:
            TypeError: if no URL has been stored by ``match``.
        """
        if not isinstance(self.url, str):
            raise TypeError("self.url must be a string")
        config = get_config()
        fname = download_yt_audio(self.url, max_length=config.max_yt_length)
        audio = load_audio(fname, sampling_rate=config.sampling_rate)
        inputs = self.processor(
            audio,
            return_tensors='pt',
            sampling_rate=config.sampling_rate,
            max_length=-1,
        )
        # Split the features into fixed-length windows of
        # 2 * max_source_positions frames, the input length the model expects
        # (see the workaround note in ``make_batch``).
        batch = self.make_batch(
            inputs['input_features'],
            max_len=2 * self.model.config.max_source_positions,
        )
        generated_ids = self.model.generate(batch)
        transcription = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )
        return self.template.format(url=self.url, content=" ".join(transcription))