"""Processors that turn a raw job input into plain text.

Each processor decides via ``match`` whether it can handle a job and
converts the job to a string via ``process``.
"""

import abc
import io
import logging
import re
from typing import Optional

import torch
import trafilatura
import urllib3
from httpx import Client
from PIL import Image
from transformers import AutoProcessor, WhisperForConditionalGeneration

from gistillery.base import JobInput
from gistillery.config import get_config
from gistillery.errors import ProcessingError
from gistillery.media import download_yt_audio, load_audio
from gistillery.tools import get_agent

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


RE_URL = re.compile(r"(https?://[^\s]+)")


def get_url(text: str) -> str | None:
    """Return the URL contained in ``text`` if there is exactly one.

    Returns ``None`` when ``text`` contains no URL or more than one URL.
    """
    urls: list[str] = RE_URL.findall(text)
    if len(urls) == 1:
        return urls[0]
    return None


class Processor(abc.ABC):
    """Base class for all processors.

    Subclasses implement ``match`` and ``process``; callers invoke the
    instance itself, which wraps ``process`` with logging and truncation of
    the result to the configured maximum length.
    """

    def __init__(self) -> None:
        self.max_length = get_config().processing_max_length
        self._super_init_called = True

    def get_name(self) -> str:
        """Return the class name of this processor."""
        return self.__class__.__name__

    def __call__(self, job: JobInput) -> str:
        """Process ``job`` and return the result, cut to ``max_length``.

        Raises:
            RuntimeError: if the subclass did not call ``super().__init__()``.
        """
        # Use getattr: if super().__init__() was skipped the attribute does
        # not exist at all, and a plain attribute access would raise
        # AttributeError instead of the intended RuntimeError.
        if not getattr(self, "_super_init_called", False):
            raise RuntimeError(
                "super().__init__() was not called with class "
                f"{self.__class__.__name__}"
            )

        _id = job.id
        logger.info(
            "Processing input with %s (id=%s)", self.__class__.__name__, _id[:8]
        )
        result = self.process(job)
        if len(result) > self.max_length:
            logger.warning(
                "Length of result (%d) exceeds max_length (%s), truncating",
                len(result),
                self.max_length,
            )
            result = result[: self.max_length]

        logger.info("Finished processing input (id=%s)", _id[:8])
        return result

    @abc.abstractmethod
    def process(self, input: JobInput) -> str:
        """Convert ``input`` to text."""
        raise NotImplementedError

    @abc.abstractmethod
    def match(self, input: JobInput) -> bool:
        """Return whether this processor can handle ``input``."""
        raise NotImplementedError


class RawTextProcessor(Processor):
    """Fallback processor: matches everything, returns the stripped content."""

    def match(self, input: JobInput) -> bool:
        return True

    def process(self, input: JobInput) -> str:
        return input.content.strip()


class DefaultUrlProcessor(Processor):
    """Fetch a web page and extract its main text with trafilatura.

    ``match`` stores the detected URL on the instance; ``process`` reads it,
    so ``match`` must have succeeded before ``process`` is called.
    """

    def __init__(self) -> None:
        super().__init__()
        self.client = Client()
        # Set by ``match``; stays None until a URL was matched.
        self.url: Optional[str] = None
        self.template = "{url}\n\n{content}"

    def match(self, input: JobInput) -> bool:
        url = get_url(input.content.strip())
        if url is None:
            return False

        self.url = url
        return True

    def process(self, input: JobInput) -> str:
        """Get content of website and return it as string

        Raises:
            TypeError: if no URL was stored by a previous ``match`` call.
            ProcessingError: if no text could be extracted from the page.
        """
        if not isinstance(self.url, str):
            raise TypeError("self.url must be a string")

        html = self.client.get(self.url).text
        extracted = trafilatura.extract(html)
        # trafilatura.extract returns None when it finds nothing; formatting
        # that into the template would yield the literal text "None".
        if not extracted:
            raise ProcessingError("No text could be extracted from website")
        return self.template.format(url=self.url, content=extracted)


class PdfUrlProcessor(Processor):
    """Download a PDF from a URL ending in ``.pdf`` and extract its text.

    ``match`` stores the detected URL on the instance; ``process`` reads it.
    """

    def __init__(self) -> None:
        super().__init__()
        self.client = Client()
        # Set by ``match``; stays None until a URL was matched.
        self.url: Optional[str] = None
        self.template = "{url}\n\n{content}"
        self.stop_words = get_config().pdf_stop_words

    def match(self, input: JobInput) -> bool:
        url = get_url(input.content.strip())
        if url is None:
            return False

        suffix = url.rsplit(".", 1)[-1].lower()
        if suffix != "pdf":
            return False

        self.url = url
        return True

    def process(self, input: JobInput) -> str:
        """Return the text of the PDF, page by page.

        Extraction stops after the first page containing one of the
        configured stop words (that page is still included).

        Raises:
            TypeError: if no URL was stored by a previous ``match`` call.
            ProcessingError: if the extracted text is empty.
        """
        if not isinstance(self.url, str):
            raise TypeError("self.url must be a string")

        response = self.client.get(self.url)

        # Imported here, as in the original code, so pypdf is only needed
        # when a PDF is actually processed.
        import pypdf

        pdf = pypdf.PdfReader(io.BytesIO(response.content))
        results = []
        for page in pdf.pages:
            page_text = page.extract_text()
            results.append(page_text)
            if any(word in page_text for word in self.stop_words):
                break

        text = "\n".join(results).strip()
        if not text:
            raise ProcessingError("No text could be extracted from PDF")

        return self.template.format(url=self.url, content=text)


class ImageUrlProcessor(Processor):
    """Download an image from a URL and caption it with the agent.

    ``match`` stores the detected URL on the instance; ``process`` reads it.
    """

    def __init__(self) -> None:
        super().__init__()
        self.client = Client()
        # Set by ``match``; stays None until a URL was matched.
        self.url: Optional[str] = None
        self.template = "{url}\n\n{content}"
        self.image_suffixes = {'jpg', 'jpeg', 'png', 'gif'}

    def match(self, input: JobInput) -> bool:
        url = get_url(input.content.strip())
        if url is None:
            return False

        suffix = url.rsplit(".", 1)[-1].lower()
        if suffix not in self.image_suffixes:
            return False

        self.url = url
        return True

    def process(self, input: JobInput) -> str:
        """Return the URL followed by a caption for the image.

        Raises:
            TypeError: if no URL was stored by a previous ``match`` call.
        """
        if not isinstance(self.url, str):
            raise TypeError("self.url must be a string")

        response = self.client.get(self.url)
        image = Image.open(io.BytesIO(response.content)).convert('RGB')
        caption = get_agent().run("Caption the following image", image=image)
        text = str(caption)
        return self.template.format(url=self.url, content=text)


class YoutubeUrlProcessor(Processor):
    """Download yt audio, transcribe with whisper"""

    def __init__(self) -> None:
        super().__init__()
        self.client = Client()
        # Set by ``match``; stays None until a URL was matched.
        self.url: Optional[str] = None
        self.template = "{url}\n\n{content}"

        self.processor = AutoProcessor.from_pretrained("openai/whisper-small.en")
        self.model = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-small.en"
        )
        self.hosts = {"www.youtube.com", "youtube.com", "youtu.be"}

    def match(self, input: JobInput) -> bool:
        url = get_url(input.content.strip())
        if url is None:
            return False

        parsed = urllib3.util.parse_url(url)
        if parsed.host not in self.hosts:
            return False

        self.url = url
        return True

    @staticmethod
    def make_batch(input_ids: torch.Tensor, max_len: int) -> torch.Tensor:
        """Create batches from last dimension, pad last batch if necessary

        The last dimension is split into chunks of ``max_len``; a short final
        chunk is zero-padded at its end. The chunks are concatenated along
        the first dimension.

        Examples

            >>> import torch
            >>> x = torch.zeros((1, 10, 213))
            >>> YoutubeUrlProcessor.make_batch(x, max_len=100).shape
            torch.Size([3, 10, 100])
        """
        # ugly workaround, transformers whisper implementation requires a
        # specific shape of input length, probably there is a better way...
        chunks = input_ids.split(max_len, dim=-1)  # type: ignore
        last = chunks[-1]
        missing = max_len - last.shape[-1]
        if missing > 0:
            # Pad only on the right so the existing frames keep their
            # position; a chunk that is already full is left untouched.
            last = torch.nn.functional.pad(last, (0, missing), value=0.0)
            chunks = chunks[:-1] + (last,)
        return torch.concat(chunks)

    def process(self, input: JobInput) -> str:
        """Return the URL followed by the transcription of the video's audio.

        Raises:
            TypeError: if no URL was stored by a previous ``match`` call.
        """
        if not isinstance(self.url, str):
            raise TypeError("self.url must be a string")

        config = get_config()
        fname = download_yt_audio(self.url, max_length=config.max_yt_length)
        audio = load_audio(fname, sampling_rate=config.sampling_rate)
        inputs = self.processor(
            audio,
            return_tensors='pt',
            sampling_rate=config.sampling_rate,
            max_length=-1,
        )
        batch = self.make_batch(
            inputs['input_features'],
            max_len=2 * self.model.config.max_source_positions,
        )
        generated_ids = self.model.generate(batch)
        transcription = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )
        return self.template.format(url=self.url, content=" ".join(transcription))