Whisper / python /whisper_svr.py
inoryQwQ
Update README, Update python API
ca02ffa
import argparse
import json
import os
import tempfile
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.parse import parse_qs
from whisper import Whisper
import cgi
# Model cache: avoids reloading a model on every request.
_model_cache = {}


def get_model(model_type, model_path, language, task):
    """Return the Whisper instance for this configuration, loading it on first use.

    Instances are memoised in ``_model_cache`` under the tuple
    ``(model_type, model_path, language, task)``, so repeated requests with
    the same settings share one loaded model.
    """
    cache_key = (model_type, model_path, language, task)
    try:
        return _model_cache[cache_key]
    except KeyError:
        pass  # first request for this configuration: load below

    print(f"Loading model: type={model_type}, path={model_path}, lang={language}, task={task}")
    instance = Whisper(model_type, model_path, language, task)
    _model_cache[cache_key] = instance
    return instance
class WhisperHandler(BaseHTTPRequestHandler):
    """HTTP request handler exposing a health check and a speech-recognition endpoint.

    Routes:
        GET  /health -- responds with ``{"status": "ok"}``.
        POST /asr    -- expects ``multipart/form-data`` with a required ``wav``
                        file part and optional ``model_type``, ``model_path``,
                        ``language`` and ``task`` fields; responds with
                        ``{"text": ...}``.

    Every response, including errors, is a JSON object.
    """

    def _send_json(self, obj, status=200):
        """Serialise ``obj`` as UTF-8 JSON and send it with the given HTTP status."""
        data = json.dumps(obj, ensure_ascii=False).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def do_GET(self):
        """Serve the ``/health`` probe; any other path yields a JSON 404."""
        if self.path == "/health":
            self._send_json({"status": "ok"})
        else:
            self._send_json({"error": "not found"}, 404)

    def _read_multipart_form(self):
        """Parse the multipart/form-data body of the current request.

        Returns:
            The parsed form (field name -> list of values), or ``None`` when
            the request is malformed; in that case a 400 response has already
            been sent.
        """
        content_type = self.headers.get('Content-Type')
        if not content_type:
            self._send_json({"error": "Missing Content-Type"}, 400)
            return None

        # NOTE: the cgi module is deprecated since Python 3.11 and removed in
        # 3.13 (PEP 594); this parsing will need a replacement on newer
        # interpreters.
        ctype, pdict = cgi.parse_header(content_type)
        # Media type names are case-insensitive, so normalise before comparing.
        if ctype.lower() != 'multipart/form-data':
            self._send_json({"error": "Only multipart/form-data is supported"}, 400)
            return None

        boundary = pdict.get('boundary')
        if not boundary:
            self._send_json({"error": "Missing multipart boundary"}, 400)
            return None

        try:
            content_length = int(self.headers.get('Content-Length'))
        except (TypeError, ValueError):
            # TypeError: header absent (None); ValueError: not a number.
            content_length = -1
        if content_length < 0:
            self._send_json({"error": "Missing or invalid Content-Length"}, 400)
            return None

        # cgi.parse_multipart expects the boundary as bytes and the body
        # length under the 'CONTENT-LENGTH' key.
        pdict['boundary'] = bytes(boundary, "utf-8")
        pdict['CONTENT-LENGTH'] = content_length
        try:
            return cgi.parse_multipart(self.rfile, pdict)
        except (ValueError, TypeError):
            self._send_json({"error": "Malformed multipart body"}, 400)
            return None

    def do_POST(self):
        """Handle ``POST /asr``: transcribe or translate an uploaded WAV file."""
        if self.path != "/asr":
            self._send_json({"error": "not found"}, 404)
            return

        form = self._read_multipart_form()
        if form is None:
            return  # the error response has already been sent

        # The wav file is mandatory.
        if "wav" not in form:
            self._send_json({"error": "Field 'wav' is required"}, 400)
            return

        # Optional parameters fall back to defaults when absent.
        # NOTE(security): model_path is taken verbatim from the client and
        # handed to the model loader; restrict it (e.g. to an allow-list)
        # before exposing this server to untrusted clients.
        model_type = form.get("model_type", ["tiny"])[0]
        model_path = form.get("model_path", ["../models/models-ax650"])[0]
        language = form.get("language", ["zh"])[0]
        task = form.get("task", ["transcribe"])[0]
        if task not in ("transcribe", "translate"):
            self._send_json({"error": "task must be 'transcribe' or 'translate'"}, 400)
            return

        wav_bytes = form["wav"][0]
        if not isinstance(wav_bytes, bytes):
            # cgi returns non-file parts (no filename) as decoded str, which
            # cannot be written back as audio.
            self._send_json({"error": "Field 'wav' must be uploaded as a file"}, 400)
            return

        wav_path = None
        try:
            # Spool the upload to disk because the model is run on a file path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                # Record the name before writing so a failed write is still
                # cleaned up in the finally clause.
                wav_path = tmp.name
                tmp.write(wav_bytes)
            model = get_model(model_type, model_path, language, task)
            result_text = model.run(wav_path)
        except Exception as e:
            # Request boundary: report any failure to the client as a JSON 500.
            self._send_json({"error": str(e)}, 500)
            return
        finally:
            if wav_path is not None and os.path.exists(wav_path):
                os.remove(wav_path)
        self._send_json({"text": result_text})
if __name__ == "__main__":
    # Script entry point: parse the command line and serve until interrupted.
    parser = argparse.ArgumentParser(description="Whisper Server")
    parser.add_argument("--host", default="0.0.0.0", help="Interface to bind the server to")
    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
    args = parser.parse_args()
    host = args.host
    port = args.port

    # Using the server as a context manager calls server_close() on exit, so
    # the listening socket is released even if serve_forever() raises.
    with HTTPServer((host, port), WhisperHandler) as server:
        print(f"Server started at http://{host}:{port}")
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            # Ctrl-C is the normal way to stop the server; exit without a traceback.
            print("Server stopped")