Spaces:
Runtime error
Runtime error
RamAnanth1
committed on
Commit
·
8d1c4b4
1
Parent(s):
f279063
Create whisper_post_processor.py
Browse files- whisper_post_processor.py +46 -0
whisper_post_processor.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from interpreter import WhisperInterpreter
|
2 |
+
from utils import VIDEO_INFO, json_dump
|
3 |
+
from yt_dlp.postprocessor import PostProcessor
|
4 |
+
from datasets import Dataset
|
5 |
+
import re
|
6 |
+
|
7 |
+
class WhisperPP(PostProcessor):
    """yt-dlp post-processor that runs a Whisper interpreter on each video.

    For every processed video the selected ``VIDEO_INFO`` fields are merged
    with the interpreter output into one record. The record is stored in
    ``data`` (a plain ``list`` or an object exposing ``add_item`` /
    ``num_rows`` / ``push_to_hub``) and optionally dumped to a JSON file
    next to the media file.

    Args:
        data: Container that accumulates the records. Either a ``list`` or a
            dataset-like object (presumably a ``datasets.Dataset`` -- confirm
            against the caller).
        **whisper_options: Options consumed here are ``model_size``
            (default ``"base"``), ``mode`` (name of the interpreter method to
            call, default ``"transcribe"``), ``write`` (dump each record to
            JSON, default ``False``) and ``number_videos`` (row count that
            triggers a push, ``0`` disables pushing). Every remaining option
            is forwarded unchanged to the interpreter method.
    """

    # Captures the text between "datasets/" and "/resolve" in a URL.
    _REPO_ID_PATTERN = re.compile(r"(?<=datasets\/)(.*?)(?=\/resolve)")

    def __init__(self, data, **whisper_options):
        super().__init__()
        self._options = whisper_options
        interpreter = WhisperInterpreter(self._options.pop("model_size", "base"))
        self.data = data
        # Resolve the interpreter method once; `run` calls it per video.
        self._process = getattr(interpreter, self._options.pop("mode", "transcribe"))
        # Defaulting to False keeps construction from failing with KeyError
        # when the caller does not pass `write` at all.
        self._write = self._options.pop("write", False)
        self.videos_to_process = self._options.pop("number_videos", 0)
        self.repoId = self._get_name()

    def run(self, info):
        """Process one video and record the result.

        Args:
            info: The info dict for the video; must contain ``id``,
                ``filepath`` and every key listed in ``VIDEO_INFO``.

        Returns:
            A tuple ``([], info)``: an empty list followed by the unmodified
            info dict.
        """
        # Local import so this block does not depend on a file-level change.
        import os

        self.to_screen(f"Processing Video {info['id']}")
        result = {key: info[key] for key in VIDEO_INFO}
        result.update(self._process(info["filepath"], **self._options))
        self.to_screen(f"Processed Video {info['id']} and appended results.")
        self._update_data(result)
        if self._write:
            # Strip only the final extension. Splitting on the first "."
            # broke paths such as "./out/a.mp4" or "My.Title.mp4".
            stem, _extension = os.path.splitext(info["filepath"])
            json_dump(result, f"{stem}.json")
        return [], info

    def _update_data(self, record):
        """Add ``record`` to ``self.data`` and push once enough rows exist.

        List containers are appended to in place and never pushed. For any
        other container the result of ``add_item`` replaces ``self.data``,
        and ``push_to_hub`` is called once ``num_rows`` reaches
        ``videos_to_process`` (a value of ``0`` disables pushing).
        """
        if isinstance(self.data, list):
            self.data.append(record)
            return

        self.data = self.data.add_item(record)
        threshold = self.videos_to_process
        # NOTE(review): repoId is "" when no name could be derived; confirm
        # how push_to_hub should behave in that case.
        if threshold != 0 and self.data.num_rows >= threshold:
            self.data.push_to_hub(self.repoId)

    def get_data(self):
        """Return the container holding all records collected so far."""
        return self.data

    def _get_name(self):
        """Derive the repository id from the data's download checksums.

        Looks at the first key of ``self.data.info.download_checksums`` and
        extracts the part between ``datasets/`` and ``/resolve``.

        Returns:
            The extracted id, or ``""`` when the data has no ``info`` (for
            example a plain list), has no checksums, or the first key does
            not match the expected URL shape.
        """
        info = getattr(self.data, "info", None)
        checksums = getattr(info, "download_checksums", None)
        if not checksums:
            return ""
        first_url = next(iter(checksums))
        match = self._REPO_ID_PATTERN.search(first_url)
        return match.group(1) if match else ""
|