aadnk commited on
Commit
6a308c6
1 Parent(s): 1a68fc3

Let max line width depend on the language

Browse files
Files changed (2) hide show
  1. app.py +24 -9
  2. utils.py +9 -6
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import Iterator
2
 
3
  from io import StringIO
@@ -15,15 +16,15 @@ from download import downloadUrl
15
 
16
  from utils import slugify, write_srt, write_vtt
17
 
18
- #import os
19
- #os.system("pip install git+https://github.com/openai/whisper.git")
20
-
21
  # Limitations (set to -1 to disable)
22
  DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
23
 
24
  # Whether or not to automatically delete all uploaded files, to save disk space
25
  DELETE_UPLOADED_FILES = True
26
 
 
 
 
27
  LANGUAGES = [
28
  "English", "Chinese", "German", "Spanish", "Russian", "Korean",
29
  "French", "Japanese", "Portuguese", "Turkish", "Polish", "Catalan",
@@ -74,8 +75,13 @@ class UI:
74
  result = model.transcribe(source, language=selectedLanguage, task=task)
75
 
76
  text = result["text"]
77
- vtt = getSubs(result["segments"], "vtt")
78
- srt = getSubs(result["segments"], "srt")
 
 
 
 
 
79
 
80
  # Files that can be downloaded
81
  downloadDirectory = tempfile.mkdtemp()
@@ -95,6 +101,15 @@ class UI:
95
  os.remove(source)
96
 
97
 
 
 
 
 
 
 
 
 
 
98
  def getSource(urlData, uploadFile, microphoneData):
99
  if urlData:
100
  # Download from YouTube
@@ -104,7 +119,7 @@ def getSource(urlData, uploadFile, microphoneData):
104
  source = uploadFile if uploadFile is not None else microphoneData
105
 
106
  file_path = pathlib.Path(source)
107
- sourceName = file_path.stem[:18] + file_path.suffix
108
 
109
  return source, sourceName
110
 
@@ -115,13 +130,13 @@ def createFile(text: str, directory: str, fileName: str) -> str:
115
 
116
  return file.name
117
 
118
- def getSubs(segments: Iterator[dict], format: str) -> str:
119
  segmentStream = StringIO()
120
 
121
  if format == 'vtt':
122
- write_vtt(segments, file=segmentStream)
123
  elif format == 'srt':
124
- write_srt(segments, file=segmentStream)
125
  else:
126
  raise Exception("Unknown format " + format)
127
 
 
1
+ import re
2
  from typing import Iterator
3
 
4
  from io import StringIO
 
16
 
17
  from utils import slugify, write_srt, write_vtt
18
 
 
 
 
19
  # Limitations (set to -1 to disable)
20
  DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
21
 
22
  # Whether or not to automatically delete all uploaded files, to save disk space
23
  DELETE_UPLOADED_FILES = True
24
 
25
+ # Gradio seems to truncate files without keeping the extension, so we need to truncate the file prefix ourself
26
+ MAX_FILE_PREFIX_LENGTH = 17
27
+
28
  LANGUAGES = [
29
  "English", "Chinese", "German", "Spanish", "Russian", "Korean",
30
  "French", "Japanese", "Portuguese", "Turkish", "Polish", "Catalan",
 
75
  result = model.transcribe(source, language=selectedLanguage, task=task)
76
 
77
  text = result["text"]
78
+
79
+ language = result["language"]
80
+ languageMaxLineWidth = getMaxLineWidth(language)
81
+
82
+ print("Max line width " + str(languageMaxLineWidth))
83
+ vtt = getSubs(result["segments"], "vtt", languageMaxLineWidth)
84
+ srt = getSubs(result["segments"], "srt", languageMaxLineWidth)
85
 
86
  # Files that can be downloaded
87
  downloadDirectory = tempfile.mkdtemp()
 
101
  os.remove(source)
102
 
103
 
104
+ def getMaxLineWidth(language: str) -> int:
105
+ if (language == "ja" or language == "zh"):
106
+ # Chinese characters and kana are wider, so limit line length to 40 characters
107
+ return 40
108
+ else:
109
+ # TODO: Add more languages
110
+ # 80 latin characters should fit on a 1080p/720p screen
111
+ return 80
112
+
113
  def getSource(urlData, uploadFile, microphoneData):
114
  if urlData:
115
  # Download from YouTube
 
119
  source = uploadFile if uploadFile is not None else microphoneData
120
 
121
  file_path = pathlib.Path(source)
122
+ sourceName = file_path.stem[:MAX_FILE_PREFIX_LENGTH] + file_path.suffix
123
 
124
  return source, sourceName
125
 
 
130
 
131
  return file.name
132
 
133
+ def getSubs(segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
134
  segmentStream = StringIO()
135
 
136
  if format == 'vtt':
137
+ write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
138
  elif format == 'srt':
139
+ write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
140
  else:
141
  raise Exception("Unknown format " + format)
142
 
utils.py CHANGED
@@ -53,10 +53,10 @@ def write_txt(transcript: Iterator[dict], file: TextIO):
53
  print(segment['text'].strip(), file=file, flush=True)
54
 
55
 
56
- def write_vtt(transcript: Iterator[dict], file: TextIO):
57
  print("WEBVTT\n", file=file)
58
  for segment in transcript:
59
- text = processText(segment['text']).replace('-->', '->')
60
 
61
  print(
62
  f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
@@ -66,7 +66,7 @@ def write_vtt(transcript: Iterator[dict], file: TextIO):
66
  )
67
 
68
 
69
- def write_srt(transcript: Iterator[dict], file: TextIO):
70
  """
71
  Write a transcript to a file in SRT format.
72
  Example usage:
@@ -79,7 +79,7 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
79
  write_srt(result["segments"], file=srt)
80
  """
81
  for i, segment in enumerate(transcript, start=1):
82
- text = processText(segment['text'].strip()).replace('-->', '->')
83
 
84
  # write srt lines
85
  print(
@@ -91,8 +91,11 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
91
  flush=True,
92
  )
93
 
94
- def processText(text: str):
95
- lines = textwrap.wrap(text, width=47, tabsize=4)
 
 
 
96
  return '\n'.join(lines)
97
 
98
  def slugify(value, allow_unicode=False):
 
53
  print(segment['text'].strip(), file=file, flush=True)
54
 
55
 
56
+ def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
57
  print("WEBVTT\n", file=file)
58
  for segment in transcript:
59
+ text = processText(segment['text'], maxLineWidth).replace('-->', '->')
60
 
61
  print(
62
  f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
 
66
  )
67
 
68
 
69
+ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
70
  """
71
  Write a transcript to a file in SRT format.
72
  Example usage:
 
79
  write_srt(result["segments"], file=srt)
80
  """
81
  for i, segment in enumerate(transcript, start=1):
82
+ text = processText(segment['text'].strip(), maxLineWidth).replace('-->', '->')
83
 
84
  # write srt lines
85
  print(
 
91
  flush=True,
92
  )
93
 
94
+ def processText(text: str, maxLineWidth=None):
95
+ if (maxLineWidth is None or maxLineWidth < 0):
96
+ return text
97
+
98
+ lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
99
  return '\n'.join(lines)
100
 
101
  def slugify(value, allow_unicode=False):