"""Speech-to-text web service: Gradio UI mounted on FastAPI, backed by Whisper.

Users upload a single audio file (.mp3/.wav/.m4a) or a .zip of audio files;
every audio file is normalised to WAV with ffmpeg and transcribed.
"""
import os
import subprocess
import tempfile
from zipfile import ZipFile

import gradio as gr
import torch
from fastapi import FastAPI
from transformers import pipeline

app = FastAPI()

# Use the larger English model only when a GPU is available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
    model_id = "openai/whisper-small.en"
else:
    model_id = "openai/whisper-tiny.en"

pipe = pipeline(
    "automatic-speech-recognition",
    model=model_id,
    chunk_length_s=30,
    device=device,
)

# Extensions that are handed to ffmpeg; compared case-insensitively.
_AUDIO_EXTENSIONS = (".m4a", ".mp3", ".wav")


def support_gbk(zip_file: ZipFile) -> ZipFile:
    """Repair GBK-encoded member names of *zip_file* in place.

    Member names that are not flagged as UTF-8 are decoded by ``zipfile``
    as cp437; for archives whose names were actually written in GBK this
    yields mojibake. Such names are re-encoded to their raw bytes and
    decoded as GBK.

    Args:
        zip_file: An open archive; its name table is mutated.

    Returns:
        The same ``zip_file`` object, so the call can be used inline in a
        ``with`` statement.
    """
    name_to_info = zip_file.NameToInfo
    # Iterate over a copy because the mapping is modified inside the loop.
    for name, info in name_to_info.copy().items():
        # Bit 11 (0x800) marks a UTF-8 name: it is already decoded
        # correctly, and encoding it as cp437 could raise.
        if info.flag_bits & 0x800:
            continue
        try:
            real_name = name.encode("cp437").decode("gbk")
        except UnicodeError:
            # Not representable as GBK: keep the name zipfile produced.
            continue
        if real_name != name:
            info.filename = real_name
            del name_to_info[name]
            name_to_info[real_name] = info
    return zip_file


def handel(f):
    """Gradio click handler: transcribe the uploaded file or archive.

    Args:
        f: The uploaded file. Either an object exposing the path as
            ``.name`` or a plain path string is accepted.

    Returns:
        The transcriptions of all processed audio files, separated by
        blank lines.

    Raises:
        gr.Error: If nothing was uploaded, or if ffmpeg fails on a file.
    """
    if not f:
        raise gr.Error("请上传文件")
    upload_path = getattr(f, "name", f)
    if not upload_path.lower().endswith(".zip"):
        return handel_files([upload_path])
    # The context manager guarantees the extracted files are removed once
    # transcription finishes, even if it fails part-way.
    with tempfile.TemporaryDirectory() as extract_dir:
        with support_gbk(ZipFile(upload_path, "r")) as z:
            z.extractall(path=extract_dir)
        return handel_files(
            [
                os.path.join(filepath, filename)
                for filepath, _, filenames in os.walk(extract_dir)
                for filename in filenames
            ]
        )


def ffmpeg_convert(file_input, file_output):
    """Convert *file_input* to *file_output* with ffmpeg, overwriting it.

    Raises:
        gr.Error: If ffmpeg exits with a non-zero status.
    """
    if subprocess.run(["ffmpeg", "-y", "-i", file_input, file_output]).returncode:
        raise gr.Error("ffmpeg_convert 失败, 请检查文件格式是否正确")


def handel_files(f_ls):
    """Convert every supported audio file to WAV and transcribe it.

    Files with an unsupported extension are skipped with a UI warning.
    All conversions run before the first transcription starts.

    Args:
        f_ls: Paths of candidate files.

    Returns:
        The transcriptions joined by blank lines, in input order.

    Raises:
        gr.Error: If ffmpeg fails on any supported file.
    """
    files = []
    for file in f_ls:
        extension = os.path.splitext(file)[1].lower()
        if extension not in _AUDIO_EXTENSIONS:
            gr.Warning(f"存在不合法文件{os.path.basename(file)},已跳过处理")
            continue
        # Append the suffix instead of substituting the extension: this
        # never touches directory components, and "a.mp3" / "a.m4a" in the
        # same folder no longer collide on "a.wav". WAV input is
        # re-encoded as well, which doubles as a validity check.
        file_output = file + ".wav"
        ffmpeg_convert(file, file_output)
        files.append(file_output)
    return "\n\n".join([whisper_handler(file) for file in files])


def whisper_handler(file):
    """Transcribe one audio file and return the recognised text.

    Shows a UI notification naming the file (text before its first dot)
    before running the speech-recognition pipeline.
    """
    file_name = os.path.basename(file)
    gr.Info(f"处理文件 - {file_name.split('.')[0]}")
    return pipe(file)["text"]
# UI layout: an upload widget, a submit button and a text box for the result.
with gr.Blocks() as blocks:
    upload_input = gr.File(file_types=[".zip", ".mp3", ".wav", ".m4a"])
    submit_button = gr.Button(value="提交")
    result_box = gr.Textbox(label="结果")
    # Clicking the button runs the transcription handler on the upload.
    submit_button.click(fn=handel, inputs=upload_input, outputs=result_box)

# Enable request queuing with a queue size limit of 3.
blocks.queue(max_size=3)

# Serve the Gradio UI from the FastAPI application at the root path.
app = gr.mount_gradio_app(app, blocks, path="/")