aadnk commited on
Commit
3fadc6e
1 Parent(s): 123e71d

Add support for downloading the results.

Browse files
Files changed (2) hide show
  1. app.py +45 -6
  2. utils.py +51 -2
app.py CHANGED
@@ -1,7 +1,11 @@
1
  from io import StringIO
 
 
 
 
2
  import gradio as gr
3
 
4
- from utils import write_vtt
5
  import whisper
6
 
7
  import ffmpeg
@@ -40,6 +44,8 @@ class UI:
40
 
41
  def transcribeFile(self, modelName, languageName, uploadFile, microphoneData, task):
42
  source = uploadFile if uploadFile is not None else microphoneData
 
 
43
  selectedLanguage = languageName.lower() if len(languageName) > 0 else None
44
  selectedModel = modelName if modelName is not None else "base"
45
 
@@ -56,14 +62,43 @@ class UI:
56
  model = whisper.load_model(selectedModel)
57
  model_cache[selectedModel] = model
58
 
 
59
  result = model.transcribe(source, language=selectedLanguage, task=task)
60
 
61
- segmentStream = StringIO()
62
- write_vtt(result["segments"], file=segmentStream)
63
- segmentStream.seek(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- return result["text"], segmentStream.read()
 
 
 
 
 
66
 
 
 
67
 
68
  def createUi(inputAudioMaxDuration, share=False):
69
  ui = UI(inputAudioMaxDuration)
@@ -81,7 +116,11 @@ def createUi(inputAudioMaxDuration, share=False):
81
  gr.Audio(source="upload", type="filepath", label="Upload Audio"),
82
  gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
83
  gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
84
- ], outputs=[gr.Text(label="Transcription"), gr.Text(label="Segments")])
 
 
 
 
85
 
86
  demo.launch(share=share)
87
 
 
1
  from io import StringIO
2
+ import os
3
+ import tempfile
4
+
5
+ from typing import Iterator
6
  import gradio as gr
7
 
8
+ from utils import slugify, write_srt, write_vtt
9
  import whisper
10
 
11
  import ffmpeg
 
44
 
45
  def transcribeFile(self, modelName, languageName, uploadFile, microphoneData, task):
46
  source = uploadFile if uploadFile is not None else microphoneData
47
+ sourceName = os.path.basename(source)
48
+
49
  selectedLanguage = languageName.lower() if len(languageName) > 0 else None
50
  selectedModel = modelName if modelName is not None else "base"
51
 
 
62
  model = whisper.load_model(selectedModel)
63
  model_cache[selectedModel] = model
64
 
65
+ # The results
66
  result = model.transcribe(source, language=selectedLanguage, task=task)
67
 
68
+ text = result["text"]
69
+ vtt = getSubs(result["segments"], "vtt")
70
+ srt = getSubs(result["segments"], "srt")
71
+
72
+ # Files that can be downloaded
73
+ downloadDirectory = tempfile.mkdtemp()
74
+ filePrefix = slugify(sourceName, allow_unicode=True)
75
+
76
+ download = []
77
+ download.append(createFile(srt, downloadDirectory, filePrefix + "-subs.srt"));
78
+ download.append(createFile(vtt, downloadDirectory, filePrefix + "-subs.vtt"));
79
+ download.append(createFile(text, downloadDirectory, filePrefix + "-transcript.txt"));
80
+
81
+ return text, vtt, download
82
+
83
+ def createFile(text: str, directory: str, fileName: str) -> str:
84
+ # Write the text to a file
85
+ with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
86
+ file.write(text)
87
+
88
+ return file.name
89
+
90
+ def getSubs(segments: Iterator[dict], format: str) -> str:
91
+ segmentStream = StringIO()
92
 
93
+ if format == 'vtt':
94
+ write_vtt(segments, file=segmentStream)
95
+ elif format == 'srt':
96
+ write_srt(segments, file=segmentStream)
97
+ else:
98
+ raise Exception("Unknown format " + format)
99
 
100
+ segmentStream.seek(0)
101
+ return segmentStream.read()
102
 
103
  def createUi(inputAudioMaxDuration, share=False):
104
  ui = UI(inputAudioMaxDuration)
 
116
  gr.Audio(source="upload", type="filepath", label="Upload Audio"),
117
  gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
118
  gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
119
+ ], outputs=[
120
+ gr.Text(label="Transcription"),
121
+ gr.Text(label="Segments"),
122
+ gr.File(label="Download")
123
+ ])
124
 
125
  demo.launch(share=share)
126
 
utils.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import zlib
2
  from typing import Iterator, TextIO
3
 
@@ -27,7 +30,7 @@ def compression_ratio(text) -> float:
27
  return len(text) / len(zlib.compress(text.encode("utf-8")))
28
 
29
 
30
- def format_timestamp(seconds: float):
31
  assert seconds >= 0, "non-negative timestamp expected"
32
  milliseconds = round(seconds * 1000.0)
33
 
@@ -40,7 +43,13 @@ def format_timestamp(seconds: float):
40
  seconds = milliseconds // 1_000
41
  milliseconds -= seconds * 1_000
42
 
43
- return (f"{hours}:" if hours > 0 else "") + f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
 
 
 
 
 
 
44
 
45
 
46
  def write_vtt(transcript: Iterator[dict], file: TextIO):
@@ -52,3 +61,43 @@ def write_vtt(transcript: Iterator[dict], file: TextIO):
52
  file=file,
53
  flush=True,
54
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unicodedata
2
+ import re
3
+
4
  import zlib
5
  from typing import Iterator, TextIO
6
 
 
30
  return len(text) / len(zlib.compress(text.encode("utf-8")))
31
 
32
 
33
+ def format_timestamp(seconds: float, always_include_hours: bool = False):
34
  assert seconds >= 0, "non-negative timestamp expected"
35
  milliseconds = round(seconds * 1000.0)
36
 
 
43
  seconds = milliseconds // 1_000
44
  milliseconds -= seconds * 1_000
45
 
46
+ hours_marker = f"{hours}:" if always_include_hours or hours > 0 else ""
47
+ return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
48
+
49
+
50
+ def write_txt(transcript: Iterator[dict], file: TextIO):
51
+ for segment in transcript:
52
+ print(segment['text'].strip(), file=file, flush=True)
53
 
54
 
55
  def write_vtt(transcript: Iterator[dict], file: TextIO):
 
61
  file=file,
62
  flush=True,
63
  )
64
+
65
+
66
+ def write_srt(transcript: Iterator[dict], file: TextIO):
67
+ """
68
+ Write a transcript to a file in SRT format.
69
+ Example usage:
70
+ from pathlib import Path
71
+ from whisper.utils import write_srt
72
+ result = transcribe(model, audio_path, temperature=temperature, **args)
73
+ # save SRT
74
+ audio_basename = Path(audio_path).stem
75
+ with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
76
+ write_srt(result["segments"], file=srt)
77
+ """
78
+ for i, segment in enumerate(transcript, start=1):
79
+ # write srt lines
80
+ print(
81
+ f"{i}\n"
82
+ f"{format_timestamp(segment['start'], always_include_hours=True)} --> "
83
+ f"{format_timestamp(segment['end'], always_include_hours=True)}\n"
84
+ f"{segment['text'].strip().replace('-->', '->')}\n",
85
+ file=file,
86
+ flush=True,
87
+ )
88
+
89
+ def slugify(value, allow_unicode=False):
90
+ """
91
+ Taken from https://github.com/django/django/blob/master/django/utils/text.py
92
+ Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
93
+ dashes to single dashes. Remove characters that aren't alphanumerics,
94
+ underscores, or hyphens. Convert to lowercase. Also strip leading and
95
+ trailing whitespace, dashes, and underscores.
96
+ """
97
+ value = str(value)
98
+ if allow_unicode:
99
+ value = unicodedata.normalize('NFKC', value)
100
+ else:
101
+ value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
102
+ value = re.sub(r'[^\w\s-]', '', value.lower())
103
+ return re.sub(r'[-\s]+', '-', value).strip('-_')