Benjamin Bossan committed
Commit 2cbbc23
1 Parent(s): 308c6f6

Add youtube transcription processor

Dockerfile CHANGED
@@ -1,6 +1,6 @@
 FROM pytorch/pytorch:latest
 
-RUN apt update && apt install -y && rm -rf /var/lib/apt/lists/*
+RUN apt update && apt install -y ffmpeg && rm -rf /var/lib/apt/lists/*
 
 # Set up a new user named "user" with user ID 1000
 RUN useradd -m -u 1000 user
demo.py CHANGED
@@ -6,6 +6,8 @@ client = httpx.Client()
 
 
 def submit(inputs):
+    if not inputs:
+        return
     payload = {"content": inputs, "author": "anna nymous"}
     httpx.post("http://localhost:8080/submit/", json=payload)
 
pyproject.toml CHANGED
@@ -16,7 +16,8 @@ addopts = "--cov=src --cov-report=term-missing"
 [tool.mypy]
 no_implicit_optional = true
 strict = true
+plugins = "numpy.typing.mypy_plugin"
 
 [[tool.mypy.overrides]]
-module = "huggingface_hub,trafilatura,transformers.*"
+module = "huggingface_hub,trafilatura,transformers.*,pytube"
 ignore_missing_imports = true
requirements-dev.txt CHANGED
@@ -5,3 +5,4 @@ ruff
 pytest
 pytest-cov
 types-Pillow
+types-urllib3
requirements.txt CHANGED
@@ -8,3 +8,5 @@ charset-normalizer
 trafilatura
 pillow
 gradio
+urllib3
+pytube
src/gistillery/config.py CHANGED
@@ -8,6 +8,8 @@ class Config(BaseSettings):
     hf_hub_token: str = "missing"
     hf_agent: str = "https://api-inference.huggingface.co/models/bigcode/starcoder"
     db_file_name: Path = Path("sqlite-data.db")
+    sampling_rate: int = 16_000  # audio transcription
+    max_yt_length: int = 1800  # in seconds (30 minutes)
 
     class Config:
         # load .env file by default, with provisio to use other .env files if set
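The two new settings feed the transcription pipeline added later in this commit: sampling_rate is the rate the Whisper feature extractor expects, and max_yt_length caps how long a video may be before download is refused. A minimal usage sketch, assuming get_config() (imported in preprocessing.py below) returns a populated Config instance:

    # Sketch only; get_config() is assumed to return a Config instance.
    from gistillery.config import get_config

    config = get_config()
    config.sampling_rate   # 16_000 Hz, the rate Whisper models expect
    config.max_yt_length   # 1800 seconds, i.e. at most 30 minutes of video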
src/gistillery/media.py ADDED
@@ -0,0 +1,68 @@
+import subprocess
+import tempfile
+
+import numpy as np
+import numpy.typing as npt
+import pytube
+
+
+def download_yt_audio(url: str, max_length: int) -> str:
+    yt = pytube.YouTube(url)
+    if (max_length is not None) and (yt.length > max_length):
+        raise ValueError(f"Youtube video exceeds max length of {max_length}")
+
+    video = yt.streams.filter(only_audio=True).first()
+    tmp_path = tempfile.mkdtemp()
+    fname = video.download(output_path=tmp_path)
+    assert isinstance(fname, str)
+    return fname
+
+
+def check_ffmpeg_installed() -> None:
+    cmd = ["ffmpeg", "-version"]  # sic
+    try:
+        subprocess.run(cmd, check=True)
+    except FileNotFoundError as exc:
+        raise RuntimeError("This feature requires ffmpeg to be installed") from exc
+
+
+# from openai whisper
+def load_audio(file: str, sampling_rate: int) -> npt.NDArray[np.float32]:
+    """Open an audio file and read as mono waveform, resampling as necessary
+
+    Parameters
+    ----------
+    file: str
+        The audio file to open
+
+    sampling_rate: int
+        The sample rate to resample the audio if necessary
+
+    Returns
+    -------
+    A NumPy array containing the audio waveform, in float32 dtype.
+
+    """
+    check_ffmpeg_installed()  # BB
+
+    # This launches a subprocess to decode audio while down-mixing
+    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
+    # fmt: off
+    cmd = [
+        "ffmpeg",
+        "-nostdin",
+        "-threads", "0",
+        "-i", file,
+        "-f", "s16le",
+        "-ac", "1",
+        "-acodec", "pcm_s16le",
+        "-ar", str(sampling_rate),
+        "-"
+    ]
+    # fmt: on
+    try:
+        out = subprocess.run(cmd, capture_output=True, check=True).stdout
+    except subprocess.CalledProcessError as e:
+        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+
+    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
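Taken together, the two helpers form a small pipeline: pytube fetches the audio-only stream, then ffmpeg decodes and resamples it into a mono float32 waveform scaled to [-1, 1]. A hedged usage sketch (the URL is a placeholder and ffmpeg must be on PATH):

    from gistillery.media import download_yt_audio, load_audio

    # placeholder URL; max_length is compared against pytube's video length in seconds
    fname = download_yt_audio("https://www.youtube.com/watch?v=<video-id>", max_length=1800)
    waveform = load_audio(fname, sampling_rate=16_000)
    print(waveform.dtype, waveform.shape)  # float32, 16_000 samples per second of audio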
src/gistillery/preprocessing.py CHANGED
@@ -4,15 +4,18 @@ import logging
 import re
 from typing import Optional
 
+import torch
 import trafilatura
+import urllib3
 from httpx import Client
-
 from PIL import Image
+from transformers import AutoProcessor, WhisperForConditionalGeneration
 
 from gistillery.base import JobInput
+from gistillery.config import get_config
+from gistillery.media import download_yt_audio, load_audio
 from gistillery.tools import get_agent
 
-
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
@@ -109,3 +112,73 @@ class ImageUrlProcessor(Processor):
         image = Image.open(io.BytesIO(response.content)).convert('RGB')
         caption = get_agent().run("Caption the following image", image=image)
         return str(caption)
+
+
+class YoutubeUrlProcessor(Processor):
+    """Download yt audio, transcribe with whisper"""
+
+    def __init__(self) -> None:
+        self.client = Client()
+        self.url: Optional[str] = None
+        self.template = "{url}\n\n{content}"
+
+        self.processor = AutoProcessor.from_pretrained("openai/whisper-small.en")
+        self.model = WhisperForConditionalGeneration.from_pretrained(
+            "openai/whisper-small.en"
+        )
+
+        self.hosts = {"www.youtube.com", "youtube.com", "youtu.be"}
+
+    def match(self, input: JobInput) -> bool:
+        url = get_url(input.content.strip())
+        if url is None:
+            return False
+
+        parsed = urllib3.util.parse_url(url)
+        if parsed.host not in self.hosts:
+            return False
+
+        self.url = url
+        return True
+
+    @staticmethod
+    def make_batch(input_ids: torch.Tensor, max_len: int) -> torch.Tensor:
+        """Create batches from last dimension, pad last batch if necessary
+
+        Examples
+        >>> import torch
+        >>> x = torch.zeros((1, 10, 213))
+        >>> YoutubeUrlProcessor.make_batch(x, max_len=100).shape
+        torch.Size([3, 10, 100])
+
+        """
+        # ugly workaround, transformers whisper implementation requires a
+        # specific shape of input length, probably there is a better way...
+        batches = input_ids.split(max_len, dim=-1)  # type: ignore
+        last = batches[-1]
+        n = last.shape[-1]
+        last = torch.nn.functional.pad(last, (1, max_len - n - 1), value=0.0)
+        batches = batches[:-1] + (last,)
+        return torch.concat(batches)
+
+    def process(self, input: JobInput) -> str:
+        if not isinstance(self.url, str):
+            raise TypeError("self.url must be a string")
+
+        config = get_config()
+        fname = download_yt_audio(self.url, max_length=config.max_yt_length)
+        audio = load_audio(fname, sampling_rate=config.sampling_rate)
+        inputs = self.processor(
+            audio,
+            return_tensors='pt',
+            sampling_rate=config.sampling_rate,
+            max_length=-1,
+        )
+        batch = self.make_batch(
+            inputs['input_features'], max_len=2 * self.model.config.max_source_positions
+        )
+        generated_ids = self.model.generate(batch)
+        transcription = self.processor.batch_decode(
+            generated_ids, skip_special_tokens=True
+        )
+        return self.template.format(url=self.url, content=" ".join(transcription))
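process() follows the standard transformers Whisper flow with one twist: features are extracted without truncation (max_length=-1) and then chopped by make_batch into chunks of 2 * max_source_positions mel frames (3000 frames, roughly 30 seconds for whisper-small.en), since the model only accepts that fixed input length. A condensed, illustrative sketch of the same calls outside the class, using the default 30-second featurization instead of the chunking workaround:

    import numpy as np
    from transformers import AutoProcessor, WhisperForConditionalGeneration

    processor = AutoProcessor.from_pretrained("openai/whisper-small.en")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small.en")

    audio = np.zeros(16_000 * 10, dtype=np.float32)  # stand-in for load_audio(...) output
    inputs = processor(audio, return_tensors="pt", sampling_rate=16_000)
    generated_ids = model.generate(inputs["input_features"])
    text = processor.batch_decode(generated_ids, skip_special_tokens=True)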
src/gistillery/registry.py CHANGED
@@ -1,11 +1,12 @@
 from gistillery.base import JobInput
-from gistillery.tools import Summarizer, Tagger, HfDefaultSummarizer, HfDefaultTagger
 from gistillery.preprocessing import (
+    DefaultUrlProcessor,
+    ImageUrlProcessor,
     Processor,
     RawTextProcessor,
-    ImageUrlProcessor,
-    DefaultUrlProcessor,
+    YoutubeUrlProcessor,
 )
+from gistillery.tools import HfDefaultSummarizer, HfDefaultTagger, Summarizer, Tagger
 
 
 class ToolRegistry:
@@ -57,6 +58,7 @@ def get_tool_registry() -> ToolRegistry:
     tagger = HfDefaultTagger()
 
     _registry = ToolRegistry()
+    _registry.register_processor(YoutubeUrlProcessor())
     _registry.register_processor(ImageUrlProcessor())
     _registry.register_processor(DefaultUrlProcessor())
     _registry.register_processor(RawTextProcessor())
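YoutubeUrlProcessor is registered ahead of ImageUrlProcessor and DefaultUrlProcessor. Assuming the registry hands each input to the first registered processor whose match() returns True (the ToolRegistry internals are not part of this diff), that ordering is what keeps a YouTube URL from being claimed by the generic URL processor. Roughly:

    # presumed dispatch logic, not shown in this commit
    def dispatch(processors, job_input):
        for processor in processors:      # registration order matters
            if processor.match(job_input):
                return processor.process(job_input)
        raise RuntimeError("no processor matched")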