import argparse
import concurrent.futures
import os

from loguru import logger
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from tqdm import tqdm
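
# Cache downloaded ModelScope models in the current working directory.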
os.environ["MODELSCOPE_CACHE"] = "./"


def transcribe_worker(file_path: str, inference_pipeline, language):
    """
    Worker function: transcribe a single audio segment (.wav file) with the given pipeline.
    """
    rec_result = inference_pipeline(audio_in=file_path)
    text = str(rec_result.get("text", "")).strip()
    # For CJK languages, drop the spaces the model inserts between tokens.
    text_without_spaces = text.replace(" ", "")
    logger.info(file_path)
    if language != "EN":
        logger.info("text: " + text_without_spaces)
        return text_without_spaces
    else:
        logger.info("text: " + text)
        return text


def transcribe_folder_parallel(folder_path, language, max_workers=4):
    """
    Transcribe all .wav files in the given folder using ThreadPoolExecutor.
    """
    logger.critical(f"parallel transcribe: {folder_path}|{language}|{max_workers}")
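    # Build one ASR pipeline per worker thread so concurrent transcriptions do
    # not share a single pipeline object.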
    if language == "JP":
        workers = [
            pipeline(
                task=Tasks.auto_speech_recognition,
                model="damo/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline",
            )
            for _ in range(max_workers)
        ]
    elif language == "ZH":
        workers = [
            pipeline(
                task=Tasks.auto_speech_recognition,
                model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                model_revision="v1.2.4",
            )
            for _ in range(max_workers)
        ]
    else:
        workers = [
            pipeline(
                task=Tasks.auto_speech_recognition,
                model="damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline",
            )
            for _ in range(max_workers)
        ]

    file_paths = []
    langs = []
    for root, _, files in os.walk(folder_path):
        for file in files:
            if file.lower().endswith(".wav"):
                file_path = os.path.join(root, file)
                file_paths.append(file_path)
                langs.append(language)

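    # Tile the pipeline list so that all_workers[i] is the pipeline used for
    # file_paths[i]; within each batch of max_workers files, every file gets
    # its own pipeline instance.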
    all_workers = (
        workers * (len(file_paths) // max_workers)
        + workers[: len(file_paths) % max_workers]
    )

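    # Submit files in batches of max_workers; executor.map pairs each file with
    # its pipeline and language and returns results in input order.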
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        for i in tqdm(range(0, len(file_paths), max_workers), desc="Transcription progress: "):
            l, r = i, min(i + max_workers, len(file_paths))
            transcriptions = list(
                executor.map(
                    transcribe_worker, file_paths[l:r], all_workers[l:r], langs[l:r]
                )
            )
            for file_path, transcription in zip(file_paths[l:r], transcriptions):
                if transcription:
                    lab_file_path = os.path.splitext(file_path)[0] + ".lab"
                    with open(lab_file_path, "w", encoding="utf-8") as lab_file:
                        lab_file.write(transcription)
    logger.critical("Transcribed each .wav file into a .lab file with the same name")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-f", "--filepath", default="./raw/lzy_zh", help="path to the folder of .wav files"
    )
    parser.add_argument("-l", "--language", default="ZH", help="language (ZH, JP, or EN)")
    parser.add_argument("-w", "--workers", default="1", help="number of transcription workers")
    args = parser.parse_args()

    transcribe_folder_parallel(args.filepath, args.language, int(args.workers))
    print("Transcription finished!")
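
# Example invocation (assuming this script is saved as transcribe_folder.py; the
# folder name is illustrative):
#   python transcribe_folder.py -f ./raw/my_speaker -l ZH -w 4
# Each foo.wav under the folder then gets a foo.lab file containing its transcription.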