import whisper
import gradio as gr
import os
from pytube import YouTube
class WhisperModelUI:
def __init__(self, ui_obj):
self.name = "Whisper Model Processor UI"
self.description = "This class is designed to build UI for our Whisper Model"
self.ui_obj = ui_obj
self.audio_files_list = ['No content']
        self.whisper_model = None  # set to a whisper.model.Whisper instance by load_whisper_model()
self.video_store_path = 'data_files'
def load_content(self, file_list):
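        # Refresh the dropdown with the .mp4/.mp3 files found under
        # data_files. file_list (the dropdown's current value) is unused but
        # required by the Gradio callback signature.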
video_out_path = os.path.join(os.getcwd(), self.video_store_path)
self.audio_files_list = [f for f in os.listdir(video_out_path)
                                 if os.path.isfile(os.path.join(video_out_path, f))
                                 and (f.endswith(".mp4") or f.endswith(".mp3"))]
return gr.Dropdown.update(choices=self.audio_files_list)
def load_whisper_model(self, model_type):
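        # Load the selected Whisper checkpoint (tiny/base/small/medium/large)
        # and keep it on the instance for the transcription handlers.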
try:
asr_model = whisper.load_model(model_type.lower())
self.whisper_model = asr_model
status = "{} ロード完了".format(model_type)
except:
status = "ロードエラー {} model".format(model_type)
return status, str(self.whisper_model)
def load_youtube_video(self, video_url):
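        # Download the highest-resolution progressive MP4 stream to
        # data_files and return its local path for the Video widget.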
video_out_path = os.path.join(os.getcwd(), self.video_store_path)
yt = YouTube(video_url)
local_video_path = yt.streams.filter(progressive=True, file_extension='mp4').order_by(
'resolution').desc().first().download(video_out_path)
return local_video_path
def get_video_to_text(self,
transcribe_or_decode,
video_list_dropdown_file_name,
language_detect,
translate_or_transcribe
):
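        # Route the request: 'Transcribe' uses whisper's high-level
        # transcribe() pipeline over the whole file, while 'Decode' runs the
        # lower-level 30-second pad/trim + decode path.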
debug_text = ""
try:
            video_out_path = os.path.join(os.getcwd(), self.video_store_path)
video_full_path = os.path.join(video_out_path, video_list_dropdown_file_name)
if not os.path.isfile(video_full_path):
video_text = "Selected video/audio is could not be located.."
else:
video_text = "Bad choice or result.."
if transcribe_or_decode == 'Transcribe':
video_text, debug_text = self.run_asr_with_transcribe(video_full_path, language_detect,
translate_or_transcribe)
elif transcribe_or_decode == 'Decode':
audio = whisper.load_audio(video_full_path)
video_text, debug_text = self.run_asr_with_decode(audio, language_detect,
translate_or_transcribe)
        except Exception as err:
            video_text = "Error processing audio: {}".format(err)
return video_text, debug_text
def run_asr_with_decode(self, audio, language_detect, translate_or_transcribe):
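        # Low-level path: decode a single 30-second window via pad/trim,
        # a log-Mel spectrogram, optional language detection, and decode().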
debug_info = "None.."
        if self.whisper_model is None:
            return "Model is not loaded, please load the model first", debug_info
try:
# pad/trim it to fit 30 seconds
audio = whisper.pad_or_trim(audio)
# make log-Mel spectrogram and move to the same device as the model
mel = whisper.log_mel_spectrogram(audio).to(self.whisper_model.device)
if language_detect == 'Detect':
# detect the spoken language
_, probs = self.whisper_model.detect_language(mel)
# print(f"Detected language: {max(probs, key=probs.get)}")
            # decode the audio; fp16=False avoids a crash on Apple's MPS backend
task_type = 'transcribe'
if translate_or_transcribe == 'Translate':
task_type = 'translate'
if language_detect != 'Detect':
options = whisper.DecodingOptions(fp16=False,
language=language_detect,
task=task_type)
else:
options = whisper.DecodingOptions(fp16=False,
task=task_type)
result = whisper.decode(self.whisper_model, mel, options)
result_text = result.text
debug_info = str(result)
        except Exception as err:
            result_text = "Error converting audio to text: {}".format(err)
return result_text, debug_info
def run_asr_with_transcribe(self, audio_path, language_detect, translate_or_transcribe):
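        # High-level path: whisper's transcribe() handles chunking over the
        # whole file, so no manual pad/trim is needed.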
result_text = "Error..."
debug_info = "None.."
        if self.whisper_model is None:
            return "Model is not loaded, please load the model first", debug_info
task_type = 'transcribe'
if translate_or_transcribe == 'Translate':
task_type = 'translate'
transcribe_options = dict(beam_size=5, best_of=5,
fp16=False,
task=task_type,
without_timestamps=False)
if language_detect != 'Detect':
transcribe_options['language'] = language_detect
transcription = self.whisper_model.transcribe(audio_path, **transcribe_options)
if transcription is not None:
result_text = transcription['text']
debug_info = str(transcription)
return result_text, debug_info
def create_whisper_ui(self):
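        # Assemble the Gradio Blocks layout: a main tab for loading models
        # and transcribing YouTube videos, a debug tab, and the event wiring.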
with self.ui_obj:
gr.Markdown("AI翻訳・書き起こし")
with gr.Tabs():
with gr.TabItem("YouTubeURLから"):
with gr.Row():
with gr.Column():
                            asr_model_type = gr.Radio(['Tiny', 'Base', 'Small', 'Medium', 'Large'],
                                                      label="Model type (accuracy)",
                                                      value='Base'
                                                      )
                            model_status_lbl = gr.Label(label="Loading status")
                            load_model_btn = gr.Button("Load model")
youtube_url = gr.Textbox(label="YouTube URL",
# value="https://www.youtube.com/watch?v=Y2nHd7El8iw"
value=""
)
                            youtube_video = gr.Video(label="Video")
                            get_video_btn = gr.Button("Load YouTube URL")
with gr.Column():
                            video_list_dropdown = gr.Dropdown(self.audio_files_list, label="Saved videos")
                            load_video_list_btn = gr.Button("Load all videos")
                            transcribe_or_decode = gr.Radio(['Transcribe', 'Decode'],
                                                            label="Option (Transcribe = full transcription)",
                                                            value='Transcribe'
                                                            )
                            language_detect = gr.Dropdown(['Detect', 'English', 'Hindi', 'Japanese'],
                                                          label="Auto-detect or select a language")
                            translate_or_transcribe = gr.Dropdown(['Transcribe', 'Translate'],
                                                                  label="Choose Translate or Transcribe")
                            get_video_txt_btn = gr.Button("Start conversion!")
                            video_text = gr.Textbox(label="Text", lines=10)
with gr.TabItem("デバッグ情報"):
with gr.Row():
with gr.Column():
debug_text = gr.Textbox(label="Debug Details", lines=20)
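            # Wire each button to its handler: inputs list first, outputs list second.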
load_model_btn.click(
self.load_whisper_model,
[
asr_model_type
],
[
model_status_lbl,
debug_text
]
)
get_video_btn.click(
self.load_youtube_video,
[
youtube_url
],
[
youtube_video
]
)
load_video_list_btn.click(
self.load_content,
[
video_list_dropdown
],
[
video_list_dropdown
]
)
get_video_txt_btn.click(
self.get_video_to_text,
[
transcribe_or_decode,
video_list_dropdown,
language_detect,
translate_or_transcribe
],
[
video_text,
debug_text
]
)
def launch_ui(self):
self.ui_obj.launch(debug=True)
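
# Minimal entry-point sketch (an assumption; not part of the original file):
# create a gr.Blocks container, hand it to WhisperModelUI, build, and launch.
if __name__ == '__main__':
    app_ui = gr.Blocks()                # container passed around as ui_obj (assumed usage)
    whisper_ui = WhisperModelUI(app_ui)
    whisper_ui.create_whisper_ui()      # lay out the tabs and wire the events
    whisper_ui.launch_ui()              # blocking call; serves the app locally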