File size: 4,394 Bytes
209aa14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474feff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import gradio as gr
import whisper
import sys
import threading
from typing import List, Union
import tqdm


class ProgressListenerHandle:
    """Context manager that registers a progress listener for the current
    thread on entry and unregisters it on exit.

    On a clean exit (no exception raised in the body) the listener's
    ``on_finished`` hook is invoked; when the body raised, ``on_finished``
    is skipped so a failed transcription is not reported as complete.
    """

    def __init__(self, listener):
        self.listener = listener

    def __enter__(self):
        register_thread_local_progress_listener(self.listener)
        # Return the listener so ``with ... as listener:`` binds something
        # useful (previously __enter__ implicitly returned None).
        return self.listener

    def __exit__(self, exc_type, exc_val, exc_tb):
        unregister_thread_local_progress_listener(self.listener)

        if exc_type is None:
            self.listener.on_finished()

class _CustomProgressBar(tqdm.tqdm):
    """tqdm subclass that mirrors every progress update to the listeners
    registered on the current thread.

    Whisper drives a tqdm bar internally; swapping this class in lets us
    observe progress even when that bar is disabled.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Track position ourselves: a disabled bar may not advance ``n``.
        self._current = self.n

    def update(self, n):
        super().update(n)
        self._current += n

        # Fan the new position out to this thread's subscribers.
        for subscriber in _get_thread_local_listeners():
            subscriber.on_progress(self._current, self.total)

_thread_local = threading.local()

def _get_thread_local_listeners():
    if not hasattr(_thread_local, 'listeners'):
        _thread_local.listeners = []
    return _thread_local.listeners

_hooked = False

def init_progress_hook():
    """Patch Whisper's internal tqdm with our listener-aware subclass.

    Idempotent: the patch is applied at most once per process. Whisper does
    not expose transcription progress through its public API, so we replace
    the tqdm class used inside ``whisper.transcribe`` instead.
    """
    global _hooked

    if _hooked:
        return

    import whisper.transcribe
    sys.modules['whisper.transcribe'].tqdm.tqdm = _CustomProgressBar
    _hooked = True

def register_thread_local_progress_listener(progress_listener):
    """Subscribe *progress_listener* to progress events on this thread.

    Installs the tqdm hook on first use, as a workaround for the progress
    bar not being exposed in Whisper's API.
    """
    init_progress_hook()
    _get_thread_local_listeners().append(progress_listener)

def unregister_thread_local_progress_listener(progress_listener):
    """Remove *progress_listener* from this thread's subscribers, if present."""
    try:
        _get_thread_local_listeners().remove(progress_listener)
    except ValueError:
        # Already unregistered (or never registered) — nothing to do.
        pass

def create_progress_listener_handle(progress_listener):
    """Wrap *progress_listener* in a context manager scoping its registration."""
    return ProgressListenerHandle(progress_listener)

class PrintingProgressListener:
    """Progress listener that forwards updates to a Gradio ``progress``
    callback and echoes them to stdout."""

    def __init__(self, progress):
        # ``progress`` is a callable like gr.Progress(): progress(fraction, desc=...)
        self.progress = progress

    def on_progress(self, current: Union[int, float], total: Union[int, float]):
        fraction = current / total
        self.progress(fraction, desc="Transcribing")
        print(f"Progress: {current}/{total}")

    def on_finished(self):
        self.progress(1, desc="Transcribed!")
        print("Finished")


import gc
import torch
from whisper.utils import get_writer
from random import random
# Whisper model sizes offered in the UI dropdown (larger = more accurate, slower).
models = ['base', 'small', 'medium', 'large']
# Output formats for the (currently disabled) writer/format option.
output_formats = ["txt", "vtt", "srt", "tsv", "json"]
# Simple one-slot model cache so repeat requests with the same size skip reloading.
locModeltype = ""  # name of the model size currently loaded ("" = none yet)
locModel = None    # the loaded whisper model instance, or None
def transcribe_audio(model, audio, progress=gr.Progress()):
    """Transcribe *audio* with the requested Whisper *model* size.

    Parameters
    ----------
    model : str
        One of the entries in ``models`` (e.g. "medium").
    audio : str
        Filesystem path to the uploaded audio file (``gr.Audio`` with
        ``type="filepath"``).
    progress : gr.Progress
        Gradio progress tracker, injected by the framework.

    Returns
    -------
    str
        The detected language followed by the transcript text.

    Raises
    ------
    gr.Error
        Wrapping any underlying failure so Gradio displays it to the user.
    """
    global locModel
    global locModeltype
    try:
        progress(0, desc="Starting...")

        # If a different model size is requested, unload the old model first
        # so two models never occupy GPU memory at once.
        if locModeltype != model:
            # Rebind (rather than ``del``) so the global name always exists,
            # even if the load below fails.
            locModel = None
            torch.cuda.empty_cache()
            gc.collect()
            progress(0, desc="Loading model...")
            locModel = whisper.load_model(model)
            # Record the loaded type only after a successful load; setting it
            # earlier would make a retry after a failed load skip the reload
            # and hit a missing model.
            locModeltype = model

        progress(0, desc="Transcribing")

        # The handle routes Whisper's internal tqdm updates to our listener,
        # which drives the Gradio progress bar.
        with create_progress_listener_handle(PrintingProgressListener(progress)):
            result = locModel.transcribe(audio, verbose=False)
            return f"language: {result['language']}\n\n{result['text']}"
    except Exception as w:
        # Surface the failure in the Gradio UI, keeping the original traceback.
        raise gr.Error(f"Error: {str(w)}") from w



# Build the web UI: a model-size selector plus an audio uploader, wired to
# transcribe_audio. Flagging is disabled since there is nothing to moderate.
_inputs = [
    gr.Dropdown(
        models,
        value=models[2],
        label="Model size",
        info="Model size determines the accuracy of the output text at the cost of speed",
    ),
    # gr.Dropdown(output_formats, value=output_formats[0], label="Output format", info="Format output text"),
    # gr.Checkbox(value=False, label="Timestamps", info="Add timestampts to know when what was said"),
    gr.Audio(label="Audio to transcribe", source='upload', type="filepath"),
]

demo = gr.Interface(
    fn=transcribe_audio,
    inputs=_inputs,
    outputs="text",
    allow_flagging="never",
)

demo.queue().launch()