Spaces:
Runtime error
Runtime error
File size: 4,394 Bytes
209aa14 474feff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import gradio as gr
import whisper
import sys
import threading
from typing import List, Union
import tqdm
class ProgressListenerHandle:
    """Context manager that registers a progress listener for the current
    thread on entry and unregisters it on exit.

    On a clean exit (no exception raised inside the ``with`` body) the
    listener's ``on_finished`` hook is invoked so it can report completion.
    """

    def __init__(self, listener):
        self.listener = listener

    def __enter__(self):
        register_thread_local_progress_listener(self.listener)
        # Bug fix: return the listener so `with ... as listener:` binds the
        # listener object instead of None (the original returned nothing).
        return self.listener

    def __exit__(self, exc_type, exc_val, exc_tb):
        unregister_thread_local_progress_listener(self.listener)
        if exc_type is None:
            self.listener.on_finished()
class _CustomProgressBar(tqdm.tqdm):
    """tqdm subclass injected into ``whisper.transcribe`` so transcription
    progress can be forwarded to the thread-local listeners even when the
    bar itself is disabled."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Track progress manually because a disabled bar does not update `n`.
        self._current = self.n

    def update(self, n=1):
        # Bug fix: tqdm's signature is update(n=1); requiring `n` would break
        # any caller that invokes update() with no argument.
        super().update(n)
        self._current += n
        # Inform every listener registered for this thread.
        for listener in _get_thread_local_listeners():
            listener.on_progress(self._current, self.total)
_thread_local = threading.local()
def _get_thread_local_listeners():
if not hasattr(_thread_local, 'listeners'):
_thread_local.listeners = []
return _thread_local.listeners
# Tracks whether the tqdm hook has already been installed.
_hooked = False

def init_progress_hook():
    """Install ``_CustomProgressBar`` into whisper's transcribe module.

    Idempotent: the patch is applied at most once per process. This makes
    Whisper's internal tqdm progress observable by our listeners.
    """
    global _hooked
    if _hooked:
        return
    # Inject our tqdm subclass into whisper.transcribe's namespace.
    import whisper.transcribe
    sys.modules['whisper.transcribe'].tqdm.tqdm = _CustomProgressBar
    _hooked = True
def register_thread_local_progress_listener(progress_listener):
    """Add *progress_listener* to the current thread's listener list.

    Whisper does not expose transcription progress through its public API,
    so the tqdm monkey-patch is (lazily) installed here as a workaround.
    """
    init_progress_hook()
    _get_thread_local_listeners().append(progress_listener)
def unregister_thread_local_progress_listener(progress_listener):
    """Remove *progress_listener* from this thread's list; no-op if absent."""
    try:
        _get_thread_local_listeners().remove(progress_listener)
    except ValueError:
        # Listener was never registered (or already removed) — ignore.
        pass
def create_progress_listener_handle(progress_listener):
    """Wrap *progress_listener* in a context manager that registers it for
    the current thread on entry and unregisters it on exit."""
    return ProgressListenerHandle(progress_listener)
class PrintingProgressListener:
    """Progress listener that forwards updates to a Gradio progress callback
    and mirrors them to stdout for server-side logging."""

    def __init__(self, progress):
        # progress: a callable such as gr.Progress(), invoked as
        # progress(fraction, desc=...).
        self.progress = progress

    def on_progress(self, current: Union[int, float], total: Union[int, float]):
        """Report fractional progress to the UI.

        Bug fix: guard against a zero/unknown total (tqdm reports
        ``total=None`` when the length is unknown), which previously raised
        on the division.
        """
        if total:
            self.progress(current / total, desc="Transcribing")
        print(f"Progress: {current}/{total}")

    def on_finished(self):
        """Signal completion to the UI."""
        self.progress(1, desc="Transcribed!")
        print("Finished")
import gc
import torch
from whisper.utils import get_writer
from random import random
# Whisper model sizes offered in the UI (larger = more accurate but slower).
models = ['base', 'small', 'medium', 'large']
# Output formats supported by whisper's writers (currently unused by the UI).
output_formats = ["txt", "vtt", "srt", "tsv", "json"]
# Cache of the currently loaded model, so repeat requests with the same
# model size skip the expensive reload.
locModeltype = ""
locModel = None
def transcribe_audio(model, audio, progress=gr.Progress()):
    """Transcribe an audio file with Whisper, reporting progress to Gradio.

    Parameters
    ----------
    model : str
        Whisper model size name (one of ``models``).
    audio : str
        Filepath of the audio to transcribe (Gradio ``type="filepath"``).
    progress : gr.Progress
        Progress tracker injected by Gradio.

    Returns
    -------
    str
        The detected language followed by the transcribed text.

    Raises
    ------
    gr.Error
        Wraps any underlying failure so Gradio displays it to the user.
    """
    global locModel
    global locModeltype
    try:
        progress(0, desc="Starting...")
        # Swap models only when the requested size differs from the cached one.
        if locModeltype != model:
            # Release the previous model before loading the new one.
            # (Rebinding to None instead of `del` keeps the global defined
            # even if the load below fails.)
            locModel = None
            torch.cuda.empty_cache()
            gc.collect()
            progress(0, desc="Loading model...")
            locModel = whisper.load_model(model)
            # Bug fix: record the cached model type only AFTER load_model
            # succeeds. Previously it was set first, so a failed load left
            # the cache claiming the model was loaded and retries with the
            # same size skipped loading and crashed on the missing model.
            locModeltype = model
        progress(0, desc="Transcribing")
        # Registers a listener so Whisper's internal tqdm updates reach the UI.
        with create_progress_listener_handle(PrintingProgressListener(progress)):
            result = locModel.transcribe(audio, verbose=False)
        return f"language: {result['language']}\n\n{result['text']}"
    except Exception as w:
        # Chain the cause so server logs keep the original traceback.
        raise gr.Error(f"Error: {str(w)}") from w
# Build the Gradio UI: model-size selector + audio upload -> transcribed text.
# (Fix: removed a stray " |" artifact that had been fused onto the final line,
# which made the file a syntax error, and dropped dead commented-out inputs.)
demo = gr.Interface(
    fn=transcribe_audio,
    inputs=[
        gr.Dropdown(models, value=models[2], label="Model size", info="Model size determines the accuracy of the output text at the cost of speed"),
        gr.Audio(label="Audio to transcribe", source='upload', type="filepath")
    ],
    allow_flagging="never",
    outputs="text")
# queue() enables progress tracking and long-running requests.
demo.queue().launch()