"""Processors that turn a job's raw input into plain text.

Each processor exposes ``match`` (can it handle this job?) and ``process``
(produce the text). The URL-based processors remember the matched URL on the
instance, so ``match`` has to be called before ``process``.
"""

import abc
import io
import logging
import re
from typing import Optional

import torch
import trafilatura
import urllib3
from httpx import Client
from PIL import Image
from transformers import AutoProcessor, WhisperForConditionalGeneration

from gistillery.base import JobInput
from gistillery.config import get_config
from gistillery.media import download_yt_audio, load_audio
from gistillery.tools import get_agent

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


# Anything starting with http:// or https:// up to the next whitespace.
RE_URL = re.compile(r"(https?://[^\s]+)")


def get_url(text: str) -> str | None:
    """Return the URL contained in ``text`` if there is exactly one.

    Returns ``None`` when ``text`` contains no URL or more than one URL.
    """
    urls: list[str] = RE_URL.findall(text)
    if len(urls) == 1:
        return urls[0]
    return None


class Processor(abc.ABC):
    """Base class of all processors; calling an instance runs ``process``."""

    def get_name(self) -> str:
        """Return the class name of the concrete processor."""
        return self.__class__.__name__

    def __call__(self, job: JobInput) -> str:
        """Run ``process`` on ``job`` and log start and end of the run."""
        _id = job.id
        # The original message interpolated the builtin ``input`` function
        # here (there is no local of that name), which only produced noise.
        logger.info(f"Processing input with {self.get_name()} (id={_id[:8]})")
        result = self.process(job)
        logger.info(f"Finished processing input (id={_id[:8]})")
        return result

    @abc.abstractmethod
    def process(self, input: JobInput) -> str:
        """Turn ``input`` into text."""
        raise NotImplementedError

    @abc.abstractmethod
    def match(self, input: JobInput) -> bool:
        """Return whether this processor can handle ``input``."""
        raise NotImplementedError


class RawTextProcessor(Processor):
    """Catch-all processor: matches everything, returns the stripped content."""

    def match(self, input: JobInput) -> bool:
        return True

    def process(self, input: JobInput) -> str:
        return input.content.strip()


class DefaultUrlProcessor(Processor):
    """Fetch a web page and return its URL followed by the extracted text."""

    def __init__(self) -> None:
        self.client = Client()
        # Set by ``match``; stays ``None`` until a job has been matched.
        self.url: Optional[str] = None
        self.template = "{url}\n\n{content}"

    def match(self, input: JobInput) -> bool:
        """Match content that contains exactly one URL and remember that URL."""
        url = get_url(input.content.strip())
        if url is None:
            return False

        self.url = url
        return True

    def process(self, input: JobInput) -> str:
        """Get content of website and return it as string

        Raises:
            TypeError: if no URL was stored by a previous ``match`` call.
        """
        if not isinstance(self.url, str):
            raise TypeError("self.url must be a string")

        html = self.client.get(self.url).text
        extracted = trafilatura.extract(html)
        if extracted is None:
            # Without this, the literal string "None" would end up as the
            # page content in the formatted result.
            logger.warning(f"Could not extract any text from {self.url}")
            extracted = ""
        return self.template.format(url=self.url, content=extracted)


class ImageUrlProcessor(Processor):
    """Download an image and return a caption produced by the agent."""

    def __init__(self) -> None:
        self.client = Client()
        # Set by ``match``; stays ``None`` until a job has been matched.
        self.url: Optional[str] = None
        self.template = "{url}\n\n{content}"
        self.image_suffixes = {'jpg', 'jpeg', 'png', 'gif'}

    def match(self, input: JobInput) -> bool:
        """Match a single URL whose text after the last dot is an image suffix.

        The suffix is taken from the raw URL string, so a URL with a query
        string after the file name does not match.
        """
        url = get_url(input.content.strip())
        if url is None:
            return False

        suffix = url.rsplit(".", 1)[-1].lower()
        if suffix not in self.image_suffixes:
            return False

        self.url = url
        return True

    def process(self, input: JobInput) -> str:
        """Fetch the image behind the stored URL and return its caption.

        Raises:
            TypeError: if no URL was stored by a previous ``match`` call.
        """
        if not isinstance(self.url, str):
            raise TypeError("self.url must be a string")

        response = self.client.get(self.url)
        image = Image.open(io.BytesIO(response.content)).convert('RGB')
        caption = get_agent().run("Caption the following image", image=image)
        return str(caption)


class YoutubeUrlProcessor(Processor):
    """Download yt audio, transcribe with whisper"""

    def __init__(self) -> None:
        self.client = Client()
        # Set by ``match``; stays ``None`` until a job has been matched.
        self.url: Optional[str] = None
        self.template = "{url}\n\n{content}"
        self.processor = AutoProcessor.from_pretrained("openai/whisper-small.en")
        self.model = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-small.en"
        )
        self.hosts = {"www.youtube.com", "youtube.com", "youtu.be"}

    def match(self, input: JobInput) -> bool:
        """Match a single URL whose host is one of the known YouTube hosts."""
        url = get_url(input.content.strip())
        if url is None:
            return False

        parsed = urllib3.util.parse_url(url)
        if parsed.host not in self.hosts:
            return False

        self.url = url
        return True

    @staticmethod
    def make_batch(input_ids: torch.Tensor, max_len: int) -> torch.Tensor:
        """Create batches from last dimension, pad last batch if necessary

        The last dimension is cut into chunks of ``max_len``. A final chunk
        that is shorter than ``max_len`` is zero-padded to ``max_len``; a
        final chunk that is already full is left untouched. The chunks are
        then concatenated along the first dimension.

        Examples

            >>> import torch
            >>> x = torch.zeros((1, 10, 213))
            >>> YoutubeUrlProcessor.make_batch(x, max_len=100).shape
            torch.Size([3, 10, 100])

        """
        # ugly workaround, transformers whisper implementation requires a
        # specific shape of input length, probably there is a better way...
        chunks = input_ids.split(max_len, dim=-1)  # type: ignore
        last = chunks[-1]
        missing = max_len - last.shape[-1]
        if missing > 0:
            # Same layout as before for short chunks: one zero frame in
            # front, the remainder behind. The guard avoids the former
            # (1, -1) padding of a full chunk, which shifted the data by one
            # frame and cropped the last real frame.
            # NOTE(review): the single leading zero frame is kept from the
            # original code; confirm whether it is actually intended.
            last = torch.nn.functional.pad(last, (1, missing - 1), value=0.0)
        return torch.concat(chunks[:-1] + (last,))

    def process(self, input: JobInput) -> str:
        """Transcribe the audio of the stored URL and return it with the URL.

        Raises:
            TypeError: if no URL was stored by a previous ``match`` call.
        """
        if not isinstance(self.url, str):
            raise TypeError("self.url must be a string")

        config = get_config()
        fname = download_yt_audio(self.url, max_length=config.max_yt_length)
        audio = load_audio(fname, sampling_rate=config.sampling_rate)
        inputs = self.processor(
            audio,
            return_tensors='pt',
            sampling_rate=config.sampling_rate,
            max_length=-1,
        )
        # Chunk length is twice the model's ``max_source_positions``.
        batch = self.make_batch(
            inputs['input_features'],
            max_len=2 * self.model.config.max_source_positions,
        )
        generated_ids = self.model.generate(batch)
        transcription = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )
        return self.template.format(url=self.url, content=" ".join(transcription))