aka7774's picture
Update fn.py
e68b7ad verified
raw
history blame
No virus
1.48 kB
import json
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
# config
# Hugging Face Hub identifier of the checkpoint passed to from_pretrained().
model_id = "kotoba-tech/kotoba-whisper-v1.0"

# Choose precision and device together: float16 on the first CUDA device
# when CUDA is available, float32 on CPU otherwise.
if torch.cuda.is_available():
    torch_dtype = torch.float16
    device = "cuda:0"
else:
    torch_dtype = torch.float32
    device = "cpu"

# Module-level state. `model` and `pipe` are filled in by load_model();
# `initial_prompt` is set through set_prompt().
model = None
pipe = None
initial_prompt = None
def load_model():
    """Load the speech-to-text model and build the recognition pipeline.

    Populates the module-level ``model`` and ``pipe`` globals. Both are
    assigned only after every step has succeeded, so an exception raised
    part-way through (for example while loading the processor or building
    the pipeline) leaves the globals untouched instead of half-initialized,
    and a later call can retry the whole load.
    """
    global model, pipe

    # Build everything in locals first; publish to the globals at the end.
    loaded_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    loaded_model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model=loaded_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        torch_dtype=torch_dtype,
        device=device,
    )

    # Publish both together so callers never observe a model without a pipe.
    model, pipe = loaded_model, asr_pipe
def set_prompt(prompt):
    """Store ``prompt`` in the module-level ``initial_prompt``.

    speech_to_text() reads this value on each call; a falsy value (such as
    ``None`` or an empty string) means no prompt is passed to generation.
    """
    global initial_prompt
    initial_prompt = prompt
def speech_to_text(audio_file, _model_size=None):
    """Transcribe ``audio_file`` with the speech-recognition pipeline.

    The model and pipeline are loaded on first use via load_model(). When
    the module-level ``initial_prompt`` is set to a truthy value, it is
    tokenized and handed to generation as ``prompt_ids``.

    Args:
        audio_file: Audio input passed straight through to the pipeline.
        _model_size: Unused by this function; kept so the signature stays
            unchanged for existing callers.

    Returns:
        A ``(text, result_json)`` tuple. ``text`` is ``result["text"]`` from
        the pipeline output. ``result_json`` is the whole pipeline result
        serialized with ``json.dumps``, or ``''`` if it cannot be serialized.
    """
    # The pipeline is what is actually used below, so make sure it exists
    # too, not only the model; use identity checks rather than truthiness.
    if model is None or pipe is None:
        load_model()

    # run inference
    generate_kwargs = {}
    if initial_prompt:
        prompt_ids = pipe.tokenizer.get_prompt_ids(
            initial_prompt, return_tensors="pt"
        )
        generate_kwargs['prompt_ids'] = prompt_ids.to(device)

    result = pipe(audio_file, generate_kwargs=generate_kwargs)

    # Best-effort serialization: only swallow the errors json.dumps raises
    # for unserializable or circular data, not KeyboardInterrupt/SystemExit.
    try:
        res = json.dumps(result)
    except (TypeError, ValueError):
        res = ''

    return result["text"], res