Spaces:
No application file
No application file
import gc | |
import os.path | |
from tempfile import NamedTemporaryFile | |
import torch | |
import whisper | |
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutomaticSpeechRecognitionPipeline | |
from gradio_client.client import DEFAULT_TEMP_DIR | |
# Module-level state shared by load(), unload() and the transcribe_* functions.
# Every slot starts as None, is reset to None by unload(), and (except for
# `processor`) is filled in by a successful load().
processor: WhisperProcessor | None = None  # only ever set to None in the code visible here
# NOTE(review): load() stores the return value of whisper.load_model() here,
# which is not one of the transformers types named below — confirm and update.
model: WhisperForConditionalGeneration | AutomaticSpeechRecognitionPipeline | None = None
device: str | None = None  # device string the current model was loaded with
loaded_model: str | None = None  # model name passed to the last successful load()
def get_official_models():
    """Return the selectable Whisper model names as a new list.

    The entries are the short names used with ``whisper.load_model``
    (``'base'``), not the Hugging Face Hub ids (``'openai/whisper-base'``).
    English-only variants are listed first, followed by the multilingual ones.
    """
    english_only = ['tiny.en', 'small.en', 'base.en', 'medium.en']
    multilingual = ['tiny', 'small', 'base', 'medium', 'large', 'large-v2']
    return english_only + multilingual
def unload():
    """Drop the loaded model and reset all module-level state to ``None``.

    After clearing the references, a garbage-collection pass is run and the
    CUDA caching allocator is asked to release its unused memory.

    Returns:
        The status string ``'Unloaded'``.
    """
    global model, processor, device, loaded_model
    model = processor = device = loaded_model = None
    # Collect first so the objects just dereferenced are gone before the
    # CUDA cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
    return 'Unloaded'
def load(pretrained_model='openai/whisper-base', map_device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Load a Whisper checkpoint into the module-level ``model``.

    Args:
        pretrained_model: Model name. Either a short openai-whisper name
            (``'base'``) or the matching Hub-style id
            (``'openai/whisper-base'``); the ``openai/whisper-`` prefix is
            stripped before calling ``whisper.load_model``, which only accepts
            the short form (or a checkpoint path).
        map_device: Device string to load the model onto. The default is
            chosen once, at import time.

    Returns:
        ``'Loaded <name>'`` on success, or ``'Failed to load, <error>'`` if
        anything raised; in the failure case all state is unloaded.
    """
    global model, processor, device, loaded_model
    try:
        # Reload when either the model or the target device changed; checking
        # only the name would silently ignore a request for another device.
        if loaded_model != pretrained_model or device != map_device:
            unload()
            checkpoint = pretrained_model.removeprefix('openai/whisper-')
            model = whisper.load_model(checkpoint, map_device, 'data/models/automatic-speech-recognition/whisper')
            loaded_model = pretrained_model
            device = map_device
        return f'Loaded {pretrained_model}'
    except Exception as e:
        # Leave no half-initialised state behind; report the error to the UI.
        unload()
        return f'Failed to load, {e}'
def transcribe(wav, files) -> tuple[str | None, list[str]]:
    """Transcribe an in-memory recording and a batch of files in one call.

    Args:
        wav: Value forwarded to ``transcribe_wav`` (a ``(sample_rate, samples)``
            pair, or ``None``).
        files: Value forwarded to ``transcribe_files``.

    Returns:
        A pair ``(text, paths)``: the result of ``transcribe_wav(wav)``
        (a string, or ``None`` when ``wav`` is ``None``) and the list of
        transcript file paths from ``transcribe_files(files)``.
    """
    return transcribe_wav(wav), transcribe_files(files)
def transcribe_wav(wav):
    """Transcribe one in-memory recording with the loaded Whisper model.

    Args:
        wav: ``None`` or a ``(sample_rate, samples)`` pair. ``samples`` is an
            array of shape ``(n,)`` or ``(n, channels)``.
            NOTE(review): assumes channels are the last axis, as the original
            ``mean(-1)`` down-mix did — confirm against the caller.

    Returns:
        The stripped transcript text; ``None`` if ``wav`` is ``None``; a
        ``'No model loaded! ...'`` message if no model is loaded; or
        ``'Exception: <error>'`` if preprocessing or transcription raised.
    """
    if loaded_model is None:
        return 'No model loaded! Please load a model.'
    if wav is None:
        return None
    sr, samples = wav
    import traceback
    try:
        audio = torch.as_tensor(samples)
        # Normalise to float for every sample rate, not only when resampling.
        # Integer PCM is scaled by the dtype's maximum (32767 for int16).
        if audio.is_floating_point():
            audio = audio.float()
        else:
            audio = audio.float() / torch.iinfo(audio.dtype).max
        # Down-mix only when a channel axis exists; averaging a 1-D mono
        # signal over its last axis would collapse it to a single number.
        if audio.ndim > 1:
            audio = audio.mean(-1)
        if sr != 16000:
            import torchaudio.functional as F
            audio = F.resample(audio.to(device).unsqueeze(0), sr, 16000).flatten()
        return whisper.transcribe(model, audio.detach().cpu().numpy())['text'].strip()
    except Exception as e:
        traceback.print_exc()
        return f'Exception: {e}'
def transcribe_files(files: list) -> list[str]:
    """Transcribe each uploaded file and write the text to a temp ``.txt`` file.

    Args:
        files: Objects exposing a ``.name`` attribute holding the path of the
            audio file; ``None`` or an empty list is allowed.

    Returns:
        Paths of the transcript files created in ``DEFAULT_TEMP_DIR``, one per
        input in order. Empty when there is nothing to process or no model is
        loaded.
    """
    if not files:
        return []
    if loaded_model is None:
        return []
    out_list = []
    for f in files:
        filename = os.path.basename(f.name)
        print('Processing ', filename)
        filename_noext = os.path.splitext(filename)[0]
        # Transcribe before creating the output file so that a failure does
        # not leave an empty transcript behind.
        text = whisper.transcribe(model, f.name)['text'].strip()
        # The context manager flushes and closes the file before its path is
        # handed back; delete=False keeps it on disk for the caller.
        with NamedTemporaryFile(dir=DEFAULT_TEMP_DIR, mode='w', delete=False, suffix='.txt', prefix=filename_noext, encoding='utf8') as out_file:
            out_file.write(text)
        out_list.append(out_file.name)
    return out_list