gistillery / src /gistillery /preprocessing.py
Benjamin Bossan
Add pdf processor using pypdf
4c2b75c
import abc
import io
import logging
import re
from typing import Optional
import torch
import trafilatura
import urllib3
from httpx import Client
from PIL import Image
from transformers import AutoProcessor, WhisperForConditionalGeneration
from gistillery.base import JobInput
from gistillery.config import get_config
from gistillery.errors import ProcessingError
from gistillery.media import download_yt_audio, load_audio
from gistillery.tools import get_agent
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
RE_URL = re.compile(r"(https?://[^\s]+)")
def get_url(text: str) -> str | None:
urls: list[str] = list(RE_URL.findall(text))
if len(urls) == 1:
url = urls[0]
return url
return None
class Processor(abc.ABC):
    """Base class for turning a ``JobInput`` into text.

    Subclasses implement ``match`` (whether this processor handles an input)
    and ``process`` (produce the text). Calling an instance runs ``process``
    and truncates the result to the configured ``processing_max_length``.
    Subclasses must call ``super().__init__()``.
    """

    def __init__(self) -> None:
        self.max_length = get_config().processing_max_length
        # Marker checked in __call__ to detect subclasses that skipped
        # super().__init__().
        self._super_init_called = True

    def get_name(self) -> str:
        """Return the processor's name (its class name)."""
        return self.__class__.__name__

    def __call__(self, job: JobInput) -> str:
        """Process ``job`` and return the result, truncated to ``max_length``.

        Raises:
            RuntimeError: if the subclass did not call ``super().__init__()``.
        """
        # getattr: when super().__init__() was skipped the attribute does not
        # exist at all, which would otherwise surface as an AttributeError
        # instead of this explanatory RuntimeError.
        if not getattr(self, "_super_init_called", False):
            raise RuntimeError(
                "super().__init__() was not called with class "
                f"{self.__class__.__name__}"
            )

        _id = job.id
        # Log a short preview of the job's content (not the builtin `input`).
        logger.info(
            f"Processing {job.content[:50]!r} with {self.__class__.__name__} "
            f"(id={_id[:8]})"
        )
        result = self.process(job)
        if len(result) > self.max_length:
            logger.warning(
                f"Length of result ({len(result)}) exceeds max_length "
                f"({self.max_length}), truncating"
            )
            result = result[: self.max_length]

        logger.info(f"Finished processing input (id={_id[:8]})")
        return result

    @abc.abstractmethod
    def process(self, input: JobInput) -> str:
        """Convert ``input`` into text."""
        raise NotImplementedError

    @abc.abstractmethod
    def match(self, input: JobInput) -> bool:
        """Return True if this processor can handle ``input``."""
        raise NotImplementedError
class RawTextProcessor(Processor):
    """Accept any input and return its content with surrounding whitespace removed."""

    def match(self, input: JobInput) -> bool:
        """Accept every input unconditionally."""
        return True

    def process(self, input: JobInput) -> str:
        """Return the input's content, stripped."""
        text = input.content
        return text.strip()
class DefaultUrlProcessor(Processor):
    """Fetch a web page and extract its main text with trafilatura.

    ``match`` stores the single URL found in the input on ``self.url``, and
    ``process`` downloads that URL. ``match`` must therefore run first.
    """

    def __init__(self) -> None:
        super().__init__()
        # httpx does not follow redirects by default; without this, redirected
        # URLs (http -> https, moved pages) would return only the redirect stub.
        self.client = Client(follow_redirects=True)
        # URL found by the last successful match(); None until then.
        self.url: Optional[str] = None
        self.template = "{url}\n\n{content}"

    def match(self, input: JobInput) -> bool:
        """Return True if the input contains exactly one URL, and remember it."""
        url = get_url(input.content.strip())
        if url is None:
            return False
        self.url = url
        return True

    def process(self, input: JobInput) -> str:
        """Get content of website and return it as string.

        Raises:
            TypeError: if no URL was stored by ``match``.
            ProcessingError: if no text could be extracted from the page.
        """
        if not isinstance(self.url, str):
            raise TypeError("self.url must be a string")

        html = self.client.get(self.url).text
        extracted = trafilatura.extract(html)
        # trafilatura.extract returns None when it finds no main content;
        # formatting that would pass the literal text "None" downstream.
        if not extracted:
            raise ProcessingError("No text could be extracted from website")
        return self.template.format(url=self.url, content=extracted)
class PdfUrlProcessor(Processor):
    """Download a PDF referenced by URL and extract its text with pypdf.

    Pages are read in order. Extraction stops after the first page whose text
    contains any of the configured ``pdf_stop_words``.
    """

    def __init__(self) -> None:
        super().__init__()
        # httpx does not follow redirects by default; PDF links often redirect.
        self.client = Client(follow_redirects=True)
        # URL found by the last successful match(); None until then.
        self.url: Optional[str] = None
        self.template = "{url}\n\n{content}"
        self.stop_words = get_config().pdf_stop_words

    def match(self, input: JobInput) -> bool:
        """Return True if the input contains exactly one URL ending in ``.pdf``."""
        url = get_url(input.content.strip())
        if url is None:
            return False

        # Drop fragment and query string so that e.g. ".../paper.pdf?dl=1"
        # is still recognized by its ".pdf" suffix.
        path = url.split("#", 1)[0].split("?", 1)[0]
        suffix = path.rsplit(".", 1)[-1].lower()
        if suffix != "pdf":
            return False

        self.url = url
        return True

    def process(self, input: JobInput) -> str:
        """Download the PDF and return its extracted text.

        Raises:
            TypeError: if no URL was stored by ``match``.
            ProcessingError: if no text could be extracted.
        """
        if not isinstance(self.url, str):
            raise TypeError("self.url must be a string")

        response = self.client.get(self.url)
        # imported here rather than at module level
        import pypdf

        pdf = pypdf.PdfReader(io.BytesIO(response.content))
        results = []
        for page in pdf.pages:
            page_text = page.extract_text()
            results.append(page_text)
            # The page containing a stop word is kept; later pages are skipped.
            if any(word in page_text for word in self.stop_words):
                break

        text = "\n".join(results).strip()
        if not text:
            raise ProcessingError("No text could be extracted from PDF")
        return self.template.format(url=self.url, content=text)
class ImageUrlProcessor(Processor):
    """Download an image referenced by URL and caption it with the agent."""

    def __init__(self) -> None:
        super().__init__()
        # httpx does not follow redirects by default; image links often redirect.
        self.client = Client(follow_redirects=True)
        # URL found by the last successful match(); None until then.
        self.url: Optional[str] = None
        self.template = "{url}\n\n{content}"
        self.image_suffixes = {'jpg', 'jpeg', 'png', 'gif'}

    def match(self, input: JobInput) -> bool:
        """Return True if the input contains exactly one URL with an image suffix."""
        url = get_url(input.content.strip())
        if url is None:
            return False

        # Drop fragment and query string so that e.g. ".../cat.png?w=300"
        # is still recognized by its file suffix.
        path = url.split("#", 1)[0].split("?", 1)[0]
        suffix = path.rsplit(".", 1)[-1].lower()
        if suffix not in self.image_suffixes:
            return False

        self.url = url
        return True

    def process(self, input: JobInput) -> str:
        """Download the image and return a generated caption.

        Raises:
            TypeError: if no URL was stored by ``match``.
        """
        if not isinstance(self.url, str):
            raise TypeError("self.url must be a string")

        response = self.client.get(self.url)
        # Normalize to RGB before handing the image to the agent.
        image = Image.open(io.BytesIO(response.content)).convert('RGB')
        caption = get_agent().run("Caption the following image", image=image)
        return self.template.format(url=self.url, content=str(caption))
class YoutubeUrlProcessor(Processor):
    """Download yt audio, transcribe with whisper"""

    def __init__(self) -> None:
        super().__init__()
        self.client = Client()
        # URL found by the last successful match(); None until then.
        self.url: Optional[str] = None
        self.template = "{url}\n\n{content}"
        self.processor = AutoProcessor.from_pretrained("openai/whisper-small.en")
        self.model = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-small.en"
        )
        self.hosts = {"www.youtube.com", "youtube.com", "youtu.be"}

    def match(self, input: JobInput) -> bool:
        """Return True if the input contains exactly one URL on a YouTube host."""
        url = get_url(input.content.strip())
        if url is None:
            return False

        parsed = urllib3.util.parse_url(url)
        if parsed.host not in self.hosts:
            return False

        self.url = url
        return True

    @staticmethod
    def make_batch(input_ids: torch.Tensor, max_len: int) -> torch.Tensor:
        """Create batches from last dimension, pad last batch if necessary

        The last dimension is split into chunks of ``max_len``. The final
        chunk is right-padded with zeros to ``max_len``, and the chunks are
        concatenated along dim 0.

        Examples
        >>> import torch
        >>> x = torch.zeros((1, 10, 213))
        >>> YoutubeUrlProcessor.make_batch(x, max_len=100).shape
        torch.Size([3, 10, 100])
        >>> y = torch.arange(200.0).reshape(1, 1, 200)
        >>> out = YoutubeUrlProcessor.make_batch(y, max_len=100)
        >>> out[1, 0, 0].item(), out[1, 0, -1].item()
        (100.0, 199.0)
        """
        # ugly workaround, transformers whisper implementation requires a
        # specific shape of input length, probably there is a better way...
        batches = input_ids.split(max_len, dim=-1)  # type: ignore
        last = batches[-1]
        n = last.shape[-1]
        # Right-pad only. The previous (1, max_len - n - 1) padding shifted the
        # last chunk by one frame and, when n == max_len, became (1, -1),
        # which dropped the final frame.
        last = torch.nn.functional.pad(last, (0, max_len - n), value=0.0)
        batches = batches[:-1] + (last,)
        return torch.concat(batches)

    def process(self, input: JobInput) -> str:
        """Download the video's audio, transcribe it, and return the text.

        Raises:
            TypeError: if no URL was stored by ``match``.
        """
        if not isinstance(self.url, str):
            raise TypeError("self.url must be a string")

        config = get_config()
        fname = download_yt_audio(self.url, max_length=config.max_yt_length)
        audio = load_audio(fname, sampling_rate=config.sampling_rate)
        inputs = self.processor(
            audio,
            return_tensors='pt',
            sampling_rate=config.sampling_rate,
            max_length=-1,
        )
        # Chunk length in feature frames that the model is fed
        # (2 * max_source_positions, per the model config).
        batch = self.make_batch(
            inputs['input_features'], max_len=2 * self.model.config.max_source_positions
        )
        generated_ids = self.model.generate(batch)
        transcription = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )
        return self.template.format(url=self.url, content=" ".join(transcription))