aadnk commited on
Commit
fdd892b
1 Parent(s): 6a308c6

Limit video length before downloading from YouTube

Browse files
Files changed (2) hide show
  1. app.py +53 -50
  2. download.py +16 -2
app.py CHANGED
@@ -12,7 +12,7 @@ import ffmpeg
12
 
13
  # UI
14
  import gradio as gr
15
- from download import downloadUrl
16
 
17
  from utils import slugify, write_srt, write_vtt
18
 
@@ -52,54 +52,70 @@ class UI:
52
  self.inputAudioMaxDuration = inputAudioMaxDuration
53
 
54
  def transcribeFile(self, modelName, languageName, urlData, uploadFile, microphoneData, task):
55
- source, sourceName = getSource(urlData, uploadFile, microphoneData)
56
-
57
  try:
58
- selectedLanguage = languageName.lower() if len(languageName) > 0 else None
59
- selectedModel = modelName if modelName is not None else "base"
60
-
61
- if self.inputAudioMaxDuration > 0:
62
- # Calculate audio length
63
- audioDuration = ffmpeg.probe(source)["format"]["duration"]
 
 
 
 
 
 
 
 
64
 
65
- if float(audioDuration) > self.inputAudioMaxDuration:
66
- return ("[ERROR]: Maximum audio file length is " + str(self.inputAudioMaxDuration) + "s, file was " + str(audioDuration) + "s"), "[ERROR]"
 
67
 
68
- model = model_cache.get(selectedModel, None)
69
-
70
- if not model:
71
- model = whisper.load_model(selectedModel)
72
- model_cache[selectedModel] = model
73
 
74
- # The results
75
- result = model.transcribe(source, language=selectedLanguage, task=task)
76
 
77
- text = result["text"]
 
 
78
 
79
- language = result["language"]
80
- languageMaxLineWidth = getMaxLineWidth(language)
 
81
 
82
- print("Max line width " + str(languageMaxLineWidth))
83
- vtt = getSubs(result["segments"], "vtt", languageMaxLineWidth)
84
- srt = getSubs(result["segments"], "srt", languageMaxLineWidth)
 
85
 
86
- # Files that can be downloaded
87
- downloadDirectory = tempfile.mkdtemp()
88
- filePrefix = slugify(sourceName, allow_unicode=True)
89
 
90
- download = []
91
- download.append(createFile(srt, downloadDirectory, filePrefix + "-subs.srt"));
92
- download.append(createFile(vtt, downloadDirectory, filePrefix + "-subs.vtt"));
93
- download.append(createFile(text, downloadDirectory, filePrefix + "-transcript.txt"));
 
 
 
 
94
 
95
- return download, text, vtt
 
 
 
 
 
 
96
 
97
- finally:
98
- # Cleanup source
99
- if DELETE_UPLOADED_FILES:
100
- print("Deleting source file " + source)
101
- os.remove(source)
102
 
 
103
 
104
  def getMaxLineWidth(language: str) -> int:
105
  if (language == "ja" or language == "zh"):
@@ -110,19 +126,6 @@ def getMaxLineWidth(language: str) -> int:
110
  # 80 latin characters should fit on a 1080p/720p screen
111
  return 80
112
 
113
- def getSource(urlData, uploadFile, microphoneData):
114
- if urlData:
115
- # Download from YouTube
116
- source = downloadUrl(urlData)
117
- else:
118
- # File input
119
- source = uploadFile if uploadFile is not None else microphoneData
120
-
121
- file_path = pathlib.Path(source)
122
- sourceName = file_path.stem[:MAX_FILE_PREFIX_LENGTH] + file_path.suffix
123
-
124
- return source, sourceName
125
-
126
  def createFile(text: str, directory: str, fileName: str) -> str:
127
  # Write the text to a file
128
  with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
 
12
 
13
  # UI
14
  import gradio as gr
15
+ from download import ExceededMaximumDuration, downloadUrl
16
 
17
  from utils import slugify, write_srt, write_vtt
18
 
 
52
  self.inputAudioMaxDuration = inputAudioMaxDuration
53
 
54
  def transcribeFile(self, modelName, languageName, urlData, uploadFile, microphoneData, task):
 
 
55
  try:
56
+ source, sourceName = self.getSource(urlData, uploadFile, microphoneData)
57
+
58
+ try:
59
+ selectedLanguage = languageName.lower() if len(languageName) > 0 else None
60
+ selectedModel = modelName if modelName is not None else "base"
61
+
62
+ if self.inputAudioMaxDuration > 0:
63
+ # Calculate audio length
64
+ audioDuration = ffmpeg.probe(source)["format"]["duration"]
65
+
66
+ if float(audioDuration) > self.inputAudioMaxDuration:
67
+ return [], ("[ERROR]: Maximum audio file length is " + str(self.inputAudioMaxDuration) + "s, file was " + str(audioDuration) + "s"), "[ERROR]"
68
+
69
+ model = model_cache.get(selectedModel, None)
70
 
71
+ if not model:
72
+ model = whisper.load_model(selectedModel)
73
+ model_cache[selectedModel] = model
74
 
75
+ # The results
76
+ result = model.transcribe(source, language=selectedLanguage, task=task)
77
+
78
+ text = result["text"]
 
79
 
80
+ language = result["language"]
81
+ languageMaxLineWidth = getMaxLineWidth(language)
82
 
83
+ print("Max line width " + str(languageMaxLineWidth))
84
+ vtt = getSubs(result["segments"], "vtt", languageMaxLineWidth)
85
+ srt = getSubs(result["segments"], "srt", languageMaxLineWidth)
86
 
87
+ # Files that can be downloaded
88
+ downloadDirectory = tempfile.mkdtemp()
89
+ filePrefix = slugify(sourceName, allow_unicode=True)
90
 
91
+ download = []
92
+ download.append(createFile(srt, downloadDirectory, filePrefix + "-subs.srt"));
93
+ download.append(createFile(vtt, downloadDirectory, filePrefix + "-subs.vtt"));
94
+ download.append(createFile(text, downloadDirectory, filePrefix + "-transcript.txt"));
95
 
96
+ return download, text, vtt
 
 
97
 
98
+ finally:
99
+ # Cleanup source
100
+ if DELETE_UPLOADED_FILES:
101
+ print("Deleting source file " + source)
102
+ os.remove(source)
103
+
104
+ except ExceededMaximumDuration as e:
105
+ return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
106
 
107
+ def getSource(self, urlData, uploadFile, microphoneData):
108
+ if urlData:
109
+ # Download from YouTube
110
+ source = downloadUrl(urlData, self.inputAudioMaxDuration)
111
+ else:
112
+ # File input
113
+ source = uploadFile if uploadFile is not None else microphoneData
114
 
115
+ file_path = pathlib.Path(source)
116
+ sourceName = file_path.stem[:MAX_FILE_PREFIX_LENGTH] + file_path.suffix
 
 
 
117
 
118
+ return source, sourceName
119
 
120
  def getMaxLineWidth(language: str) -> int:
121
  if (language == "ja" or language == "zh"):
 
126
  # 80 latin characters should fit on a 1080p/720p screen
127
  return 80
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  def createFile(text: str, directory: str, fileName: str) -> str:
130
  # Write the text to a file
131
  with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
download.py CHANGED
@@ -13,7 +13,7 @@ class FilenameCollectorPP(PostProcessor):
13
  self.filenames.append(information["filepath"])
14
  return [], information
15
 
16
- def downloadUrl(url: str):
17
  destinationDirectory = mkdtemp()
18
 
19
  ydl_opts = {
@@ -26,6 +26,13 @@ def downloadUrl(url: str):
26
  filename_collector = FilenameCollectorPP()
27
 
28
  with YoutubeDL(ydl_opts) as ydl:
 
 
 
 
 
 
 
29
  ydl.add_post_processor(filename_collector)
30
  ydl.download([url])
31
 
@@ -35,4 +42,11 @@ def downloadUrl(url: str):
35
  result = filename_collector.filenames[0]
36
  print("Downloaded " + result)
37
 
38
- return result
 
 
 
 
 
 
 
 
13
  self.filenames.append(information["filepath"])
14
  return [], information
15
 
16
+ def downloadUrl(url: str, maxDuration: int = None):
17
  destinationDirectory = mkdtemp()
18
 
19
  ydl_opts = {
 
26
  filename_collector = FilenameCollectorPP()
27
 
28
  with YoutubeDL(ydl_opts) as ydl:
29
+ if maxDuration:
30
+ info = ydl.extract_info(url, download=False)
31
+ duration = info['duration']
32
+
33
+ if duration >= maxDuration:
34
+ raise ExceededMaximumDuration(videoDuration=duration, maxDuration=maxDuration, message="Video is too long")
35
+
36
  ydl.add_post_processor(filename_collector)
37
  ydl.download([url])
38
 
 
42
  result = filename_collector.filenames[0]
43
  print("Downloaded " + result)
44
 
45
+ return result
46
+
47
+
48
+ class ExceededMaximumDuration(Exception):
49
+ def __init__(self, videoDuration, maxDuration, message):
50
+ self.videoDuration = videoDuration
51
+ self.maxDuration = maxDuration
52
+ super().__init__(message)