import gc
import os.path
import traceback
from tempfile import NamedTemporaryFile

import torch
import whisper
from transformers import WhisperProcessor
from gradio_client.client import DEFAULT_TEMP_DIR

# Module-level state for the currently loaded model.
processor: WhisperProcessor | None = None  # unused by the whisper path; kept for API stability
model: whisper.Whisper | None = None  # whisper.load_model() returns a whisper.Whisper
device: str | None = None
loaded_model: str | None = None


def get_official_models():
    """Names of the official openai-whisper checkpoints accepted by whisper.load_model()."""
    # The Hugging Face hub IDs ('openai/whisper-tiny.en' ... 'openai/whisper-large-v2')
    # could be returned here instead if the transformers pipeline path were used.
    return [
        'tiny.en',
        'small.en',
        'base.en',
        'medium.en',
        'tiny',
        'small',
        'base',
        'medium',
        'large',
        'large-v2',
    ]


def unload():
    """Release the loaded model and return its memory to the system."""
    global model, processor, device, loaded_model
    model = None
    processor = None
    device = None
    loaded_model = None
    gc.collect()
    torch.cuda.empty_cache()  # no-op when CUDA is unavailable
    return 'Unloaded'


def load(pretrained_model='base', map_device=None):
    """Load an openai-whisper checkpoint, replacing any previously loaded model."""
    global model, processor, device, loaded_model
    # Resolve the device at call time rather than freezing it at import time.
    if map_device is None:
        map_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    try:
        if loaded_model != pretrained_model:
            unload()
            # A transformers pipeline('automatic-speech-recognition', ...) could be used
            # here instead; this module loads the openai-whisper checkpoint directly.
            # Note: whisper.load_model() expects checkpoint names like 'base', not hub
            # IDs like 'openai/whisper-base'.
            model = whisper.load_model(pretrained_model, map_device, 'data/models/automatic-speech-recognition/whisper')
            loaded_model = pretrained_model
            device = map_device
        return f'Loaded {pretrained_model}'
    except Exception as e:
        unload()
        return f'Failed to load: {e}'


def transcribe(wav, files) -> tuple[str | None, list[str]]:
    """Transcribe both a recorded clip and a batch of uploaded files in one call."""
    return transcribe_wav(wav), transcribe_files(files)


def transcribe_wav(wav):
    """Transcribe a Gradio audio value given as a (sample_rate, int16 samples) tuple."""
    global model, processor, device, loaded_model
    if loaded_model is None:
        return 'No model loaded! Please load a model.'
    if wav is None:
        return None
    sr, wav = wav
    try:
        # Whisper expects mono float32 audio in [-1, 1] sampled at 16 kHz.
        wav = torch.tensor(wav).to(device).float() / 32768.0
        if wav.ndim > 1:
            wav = wav.mean(-1)  # downmix multi-channel audio; 1-D input is already mono
        if sr != 16000:
            import torchaudio.functional as F  # imported lazily; only needed for resampling
            wav = F.resample(wav.unsqueeze(0), sr, 16000).flatten()
            sr = 16000
        return whisper.transcribe(model, wav.cpu().numpy())['text'].strip()
    except Exception as e:
        traceback.print_exception(e)
        return f'Exception: {e}'


def transcribe_files(files: list) -> list[str]:
    """Transcribe uploaded audio files; returns paths of .txt files holding the transcripts."""
    if files is None or len(files) == 0:
        return []
    global model, processor, device, loaded_model
    if loaded_model is None:
        return []
    out_list = []
    for f in files:
        filename = os.path.basename(f.name)
        print('Processing', filename)
        filename_noext = os.path.splitext(filename)[0]
        # delete=False keeps the transcript on disk so Gradio can offer it for download;
        # the context manager closes the handle instead of leaking it.
        with NamedTemporaryFile(dir=DEFAULT_TEMP_DIR, mode='w', delete=False,
                                suffix='.txt', prefix=filename_noext, encoding='utf8') as out_file:
            out_file.write(whisper.transcribe(model, f.name)['text'].strip())
            out_list.append(out_file.name)
    return out_list
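

# A minimal usage sketch, not part of the Gradio app itself: it assumes an audio
# file named 'sample.wav' (hypothetical) exists in the working directory and that
# the 'base' checkpoint can be downloaded. The app normally drives load() /
# transcribe_files() / unload() through its UI callbacks.
if __name__ == '__main__':
    print(load('base'))

    class _File:
        """Stand-in for the file objects Gradio passes in; only .name is used."""
        def __init__(self, name):
            self.name = name

    for txt_path in transcribe_files([_File('sample.wav')]):
        with open(txt_path, encoding='utf8') as fh:
            print(fh.read())
    print(unload())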