# kn/asr_transcript.py
import argparse
import concurrent.futures
import os

from loguru import logger
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from tqdm import tqdm
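
# Cache downloaded ModelScope models in the current working directory.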
os.environ["MODELSCOPE_CACHE"] = "./"


def transcribe_worker(file_path: str, inference_pipeline, language):
    """
    Transcribe a single audio file with the given ModelScope inference pipeline.
    """
    rec_result = inference_pipeline(audio_in=file_path)
    text = str(rec_result.get("text", "")).strip()
    logger.info(file_path)
    if language != "EN":
        # The Chinese/Japanese models may insert spaces between tokens; drop them.
        text_without_spaces = text.replace(" ", "")
        logger.info("text: " + text_without_spaces)
        return text_without_spaces
    else:
        logger.info("text: " + text)
        return text


def transcribe_folder_parallel(folder_path, language, max_workers=4):
    """
    Transcribe every .wav file under folder_path in parallel and write each
    result to a .lab file with the same base name.
    """
    logger.critical(f"parallel transcribe: {folder_path}|{language}|{max_workers}")
if language == "JP":
workers = [
pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline",
)
for _ in range(max_workers)
]
elif language == "ZH":
workers = [
pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
model_revision="v1.2.4",
)
for _ in range(max_workers)
]
else:
workers = [
pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline",
)
for _ in range(max_workers)
]
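
    # Recursively collect every .wav file under folder_path.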
file_paths = []
langs = []
for root, _, files in os.walk(folder_path):
for file in files:
if file.lower().endswith(".wav"):
file_path = os.path.join(root, file)
                file_paths.append(file_path)
langs.append(language)
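
    # Repeat the pipelines so all_workers lines up one-to-one with file_paths,
    # cycling through the max_workers pipelines batch by batch.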
all_workers = (
workers * (len(file_paths) // max_workers)
+ workers[: len(file_paths) % max_workers]
)
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
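        # Process the files in batches of max_workers, pairing each file with its own pipeline.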
        for i in tqdm(range(0, len(file_paths), max_workers), desc="Transcribing: "):
l, r = i, min(i + max_workers, len(file_paths))
transcriptions = list(
executor.map(
transcribe_worker, file_paths[l:r], all_workers[l:r], langs[l:r]
)
)
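            # Write each non-empty transcription to a .lab file next to its .wav source.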
for file_path, transcription in zip(file_paths[l:r], transcriptions):
if transcription:
lab_file_path = os.path.splitext(file_path)[0] + ".lab"
with open(lab_file_path, "w", encoding="utf-8") as lab_file:
lab_file.write(transcription)
logger.critical("已经将wav文件转写为同名的.lab文件")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-f", "--filepath", default="./raw/lzy_zh", help="path of the folder containing .wav files"
    )
    parser.add_argument("-l", "--language", default="ZH", help="language code: ZH, JP, or EN")
    parser.add_argument("-w", "--workers", type=int, default=1, help="number of transcription workers")
    args = parser.parse_args()
    transcribe_folder_parallel(args.filepath, args.language, args.workers)
    print("Transcription finished!")