# xc_expt / asr_transcript.py
# Imported from commit 37ab3e3 ("init") by isuneast.
import argparse
import concurrent.futures
import os
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from tqdm import tqdm
from tools.log import logger
# Set MODELSCOPE_CACHE to the current directory (presumably where ModelScope
# stores downloaded models).
# NOTE(review): this runs after the modelscope imports above; confirm modelscope
# reads MODELSCOPE_CACHE lazily (at download time), otherwise move it above them.
os.environ.update(MODELSCOPE_CACHE="./")
def transcribe_worker(file_path: str, inference_pipeline, language):
    """Return the transcript for a single audio file.

    If a ``.lab`` file with the same base name already exists, its contents are
    returned verbatim and the model is not called. Otherwise the audio is run
    through *inference_pipeline*, and the ``"text"`` field of the result is
    stripped of surrounding whitespace. For every language except "EN", all
    space characters are also removed.

    Args:
        file_path: Path to the audio file.
        inference_pipeline: Callable invoked as
            ``inference_pipeline(audio_in=file_path)``. Its result is assumed to
            be dict-like with an optional ``"text"`` key (TODO confirm for the
            installed modelscope version).
        language: Language code; only "EN" keeps inner spaces.

    Returns:
        The transcript text (possibly empty).
    """
    stem, _ = os.path.splitext(file_path)
    cached_lab = stem + '.lab'
    # isfile() is False for missing paths, so it alone covers "exists and is a file".
    if os.path.isfile(cached_lab):
        logger.info(f'{cached_lab}为已转写的文本,跳过~')
        with open(cached_lab, 'r', encoding='utf-8') as lab:
            return lab.read()

    result = inference_pipeline(audio_in=file_path)
    transcript = str(result.get("text", "")).strip()
    logger.info(file_path)
    if language == "EN":
        logger.info("text: " + transcript)
        return transcript
    # Spaces are dropped for every non-"EN" language (presumably because ZH/JP
    # text is written without word separators).
    compact = transcript.replace(" ", "")
    logger.info("text: " + compact)
    return compact
def _asr_pipeline_kwargs(language):
    """Return the ``pipeline(...)`` keyword arguments selecting the ASR model for *language*.

    "JP" and "ZH" have dedicated models; every other value (including "EN")
    falls back to the English UniASR model.
    """
    kwargs = {"task": Tasks.auto_speech_recognition}
    if language == "JP":
        kwargs["model"] = (
            "damo/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline"
        )
    elif language == "ZH":
        kwargs["model"] = (
            "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
        )
        kwargs["model_revision"] = "v1.2.4"
    else:
        kwargs["model"] = (
            "damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline"
        )
    return kwargs


def transcribe_folder_parallel(folder_path, language, max_workers=4):
    """Transcribe every ``.wav`` file under *folder_path* into a same-named ``.lab`` file.

    Files are processed in batches of up to *max_workers* on a thread pool.
    Within a batch, every file is paired with its own ASR pipeline instance, so
    no pipeline object is used by two threads at the same time. A ``.lab`` file
    is written only when the returned transcription is non-empty.

    Args:
        folder_path: Directory searched recursively for ``.wav`` files
            (extension matched case-insensitively).
        language: "JP", "ZH", or any other value for the English model; also
            forwarded to ``transcribe_worker``.
        max_workers: Maximum number of concurrent transcriptions and maximum
            number of ASR pipelines loaded. Must be >= 1.

    Raises:
        ValueError: If *max_workers* is less than 1.
    """
    if max_workers < 1:
        # Without this check, 0 failed with ZeroDivisionError and negatives with
        # an executor error, both only after walking the folder.
        raise ValueError(f"max_workers must be >= 1, got {max_workers}")
    logger.info(f"parallel transcribe: {folder_path}|{language}|{max_workers}")

    file_paths = [
        os.path.join(root, name)
        for root, _, files in os.walk(folder_path)
        for name in files
        if name.lower().endswith(".wav")
    ]

    if file_paths:
        # Loading a model is expensive: never load more pipelines than files,
        # and load none at all when there is nothing to transcribe.
        pipeline_kwargs = _asr_pipeline_kwargs(language)
        workers = [
            pipeline(**pipeline_kwargs)
            for _ in range(min(max_workers, len(file_paths)))
        ]
        batch_size = len(workers)
        with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
            for start in tqdm(range(0, len(file_paths), batch_size), desc="转写进度: "):
                batch = file_paths[start:start + batch_size]
                # executor.map zips its iterables (stopping at the shortest), so
                # batch[k] is paired with workers[k]: one file per pipeline.
                transcriptions = list(
                    executor.map(
                        transcribe_worker, batch, workers, [language] * len(batch)
                    )
                )
                for file_path, transcription in zip(batch, transcriptions):
                    if transcription:
                        lab_file_path = os.path.splitext(file_path)[0] + ".lab"
                        with open(lab_file_path, "w", encoding="utf-8") as lab_file:
                            lab_file.write(transcription)
    logger.info("已经将wav文件转写为同名的.lab文件")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Transcribe every .wav file in a folder into a same-named .lab file."
    )
    parser.add_argument(
        "-f",
        "--filepath",
        default="./raw/lzy_zh",
        help="folder containing the .wav files to transcribe (searched recursively)",
    )
    parser.add_argument(
        "-l",
        "--language",
        default="ZH",
        help="audio language: ZH, JP or EN (any other value uses the English model)",
    )
    # type=int: a non-numeric value is reported by argparse instead of crashing
    # later in int().
    parser.add_argument(
        "-w",
        "--workers",
        type=int,
        default=1,
        help="number of parallel transcription workers (each uses its own ASR model instance)",
    )
    args = parser.parse_args()
    transcribe_folder_parallel(args.filepath, args.language, args.workers)
    print("转写结束!")