# Source: fish-speech-1 / tools/auto_rerank.py
# Last commit: 69e8a46 ("update to 1.2"), by PoTaTo721
import os
os.environ["MODELSCOPE_CACHE"] = ".cache/"
import string
import time
from threading import Lock
import librosa
import numpy as np
import opencc
import torch
from faster_whisper import WhisperModel
# Shared OpenCC converter built once at import time with the "t2s" profile
# (Traditional -> Simplified Chinese per the OpenCC configuration names).
# batch_asr_internal runs every transcript through it.
t2s_converter = opencc.OpenCC("t2s")
def load_model(*, device="cuda", compute_type=None):
    """Load the faster-whisper ``medium`` model used for transcription.

    Args:
        device: Device string forwarded to ``WhisperModel`` (keyword-only).
        compute_type: Compute type forwarded to ``WhisperModel``. ``None``
            (the default) keeps the historical ``"float16"`` for every device
            except ``"cpu"``, where ``"default"`` is used instead.

    Returns:
        The constructed ``WhisperModel``. Weights are cached under the
        ``faster_whisper`` directory (``download_root``).
    """
    if compute_type is None:
        # An explicit float16 request is refused by the CTranslate2 backend on
        # CPU, so let the backend pick a supported type there. Every other
        # device keeps the original float16 setting.
        compute_type = "default" if device == "cpu" else "float16"

    model = WhisperModel(
        "medium",
        device=device,
        compute_type=compute_type,
        download_root="faster_whisper",
    )
    print("faster_whisper loaded!")
    return model
def _summarize_segments(segments, max_gap_s=3.0):
    """Join segment texts; stop after the first gap longer than ``max_gap_s``.

    Returns ``(text, huge_gap)``. ``text`` is ``""`` when there are no
    segments. When a huge gap is found, the segment that follows the gap is
    still included in ``text`` and the remaining segments are ignored.
    """
    pieces = []
    last_end = None
    for seg in segments:
        pieces.append(seg.text.strip())
        # Gap = silence between the previous segment's end and this one's start.
        if last_end is not None and seg.start - last_end > max_gap_s:
            return "".join(pieces), True
        last_end = seg.end
    return "".join(pieces), False


@torch.no_grad()
def batch_asr_internal(model: WhisperModel, audios, sr):
    """Transcribe each waveform and summarise the result.

    Args:
        model: Model whose ``transcribe`` method is called once per waveform.
        audios: Iterable of waveforms, each a NumPy array or a torch tensor
            that squeezes down to one dimension.
        sr: Sample rate shared by all waveforms in ``audios``.

    Returns:
        A list with one dict per input, in input order:
        ``"text"`` (transcript after the ``t2s_converter`` pass; empty string
        if nothing was recognised), ``"duration"`` (input length in
        milliseconds) and ``"huge_gap"`` (``True`` if two consecutive
        segments are more than 3 seconds apart).
    """
    resampled_audios = []
    durations_ms = []
    for audio in audios:
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio).float()
        if audio.dim() > 1:
            audio = audio.squeeze()
        assert audio.dim() == 1, "each audio must squeeze to a 1-D waveform"

        # detach()/cpu() make this work for CUDA or grad-tracking tensors too.
        audio_np = audio.detach().cpu().numpy()
        # Measure the squeezed waveform: len() of a (1, N) input would be 1.
        durations_ms.append(audio_np.shape[0] / sr * 1000)
        resampled_audios.append(
            librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
        )

    results = []
    for resampled_audio, duration in zip(resampled_audios, durations_ms):
        segments, _info = model.transcribe(
            resampled_audio,
            language=None,
            beam_size=5,
            initial_prompt="Punctuation is needed in any language.",
        )
        text, huge_gap = _summarize_segments(segments)
        results.append(
            {
                "text": t2s_converter.convert(text),
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )
    return results
# Module-level lock. Nothing in the code visible here acquires it.
# NOTE(review): confirm no external importer relies on it before removing.
global_lock = Lock()
def batch_asr(model, audios, sr):
    """Transcribe ``audios`` by delegating to ``batch_asr_internal``.

    Arguments and return value are passed through unchanged.
    """
    transcriptions = batch_asr_internal(model, audios, sr)
    return transcriptions
def is_chinese(text):
    """Return ``True`` unconditionally; ``text`` is not inspected.

    NOTE(review): presumably a placeholder for real language detection. No
    caller is visible in this file, so confirm usage before changing it.
    """
    return True
def calculate_wer(text1, text2, debug=False):
    """Return the character-level error rate between two texts.

    Both texts are first passed through ``remove_punctuation``. The result is
    the Levenshtein distance between the remaining character sequences,
    divided by the length of the longer one, so it lies in ``[0, 1]``.

    Args:
        text1: Reference ("gt") text.
        text2: Predicted text.
        debug: When true, print both cleaned texts and the computed ratio.

    Returns:
        The error rate as a float. ``0.0`` when both texts are empty after
        punctuation removal (identical texts, nothing to divide by).
    """
    reference = remove_punctuation(text1)
    hypothesis = remove_punctuation(text2)

    # Edit distance is symmetric, so index the DP row by the shorter string to
    # keep it small. The originals stay untouched for the debug output below.
    short, long_ = sorted((reference, hypothesis), key=len)

    prev = list(range(len(short) + 1))  # distances from the empty prefix
    for j, long_char in enumerate(long_, start=1):
        curr = [j]
        for i, short_char in enumerate(short, start=1):
            if short_char == long_char:
                curr.append(prev[i - 1])
            else:
                # deletion, insertion, substitution
                curr.append(min(prev[i], curr[i - 1], prev[i - 1]) + 1)
        prev = curr

    edits = prev[-1]
    tot = len(long_)
    # Both texts empty: nothing differs and there is nothing to divide by.
    wer = edits / tot if tot else 0.0

    if debug:
        print(" gt: ", reference)
        print(" pred: ", hypothesis)
        print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
    return wer
def _build_punctuation_table():
    """Build the ``str.translate`` table that deletes whitespace and punctuation."""
    # Full-width/half-width punctuation is generated from inclusive code-point
    # ranges instead of literal glyphs. A literal full-width quote or hash
    # that gets normalised to ASCII ends the string early and turns the rest
    # of the line into a comment, silently dropping most of the set.
    fullwidth_ranges = (
        (0xFF01, 0xFF0F),  # fullwidth ! " # $ % & ' ( ) * + , - . /
        (0xFF1A, 0xFF20),  # fullwidth : ; < = > ? @
        (0xFF3B, 0xFF40),  # fullwidth [ \ ] ^ _ `
        (0xFF5B, 0xFF65),  # fullwidth { | } ~, white parens, halfwidth CJK marks
    )
    fullwidth = "".join(
        chr(code_point)
        for first, last in fullwidth_ranges
        for code_point in range(first, last + 1)
    )
    cjk = "、。〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿⦅⦆"
    # en/em dash, curly and low quotes, ellipsis, hyphenation point, wavy low line
    quotes_and_dashes = (
        "\u2013\u2014\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f"
    )
    whitespace = " \n\t"
    all_punctuation = (
        string.punctuation + whitespace + fullwidth + cjk + quotes_and_dashes
    )
    return str.maketrans("", "", all_punctuation)


# Built once at import time; the table never changes between calls.
_PUNCTUATION_TABLE = _build_punctuation_table()


def remove_punctuation(text):
    """Return ``text`` with whitespace and punctuation removed.

    Strips spaces, newlines and tabs, ASCII punctuation
    (``string.punctuation``), full-width punctuation, CJK brackets and marks,
    and typographic quotes, dashes and ellipsis. All other characters are
    kept in their original order.
    """
    return text.translate(_PUNCTUATION_TABLE)
if __name__ == "__main__":
    # Manual smoke test: transcribe two local clips once, then time ten more
    # passes over the same clips.
    sample_rate = 44100
    clip_paths = ("44100.wav", "lengyue.wav")

    model = load_model()
    audios = [librosa.load(path, sr=sample_rate)[0] for path in clip_paths]

    print(np.array(audios[0]))
    # First call happens outside the timed region.
    print(batch_asr(model, audios, sample_rate))

    start_time = time.time()
    for _ in range(10):
        print(batch_asr(model, audios, sample_rate))
    elapsed = time.time() - start_time
    print("Time taken:", elapsed)