"""Flask server exposing OpenAI Whisper transcription over HTTP.

POST an audio file as multipart field ``file`` to ``/`` (larger model) or
``/small`` (smaller model). The response is JSON with ``results`` (one
``"<start>s: <text>"`` line per segment) and ``text`` (full transcription).
"""
import logging
import os
import uuid
from asyncio.log import logger  # NOTE(review): unused; this is asyncio's logger, not this module's

import whisper
from flask import Flask, jsonify, request
from werkzeug.utils import secure_filename

try:
    import gunicorn.app.base
    gunicorn_present = True
except ImportError:
    logging.exception("gunicorn not installed")
    gunicorn_present = False

# NOTE(review): this unconditionally disables gunicorn even when it is
# installed, so the server always runs on Flask's built-in server.
# Presumably a deliberate switch -- delete this line to use gunicorn.
gunicorn_present = False

STARTING_SIZE = 'small'
UPLOAD_FOLDER = 'uploads'
ALLOWED_EXTENSIONS = {'ogg', 'mp3', 'mp4', 'wav', 'flac',
                      'm4a', 'aac', 'wma', 'webm', 'opus'}
normal_size = STARTING_SIZE
small_size = 'base'
PORT = 7860

app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# FileStorage.save does not create missing directories.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Each size is loaded twice: the multilingual model (also used for language
# detection) and its English-only ".en" variant.
model = whisper.load_model(normal_size)
model_en = whisper.load_model(f"{normal_size}.en")
model_small = whisper.load_model(small_size)
model_small_en = whisper.load_model(f"{small_size}.en")

if gunicorn_present:
    class StandaloneApplication(gunicorn.app.base.BaseApplication):
        """Run a WSGI app under gunicorn with options passed in from code."""

        def __init__(self, app, options=None):
            self.options = options or {}
            self.application = app
            super().__init__()

        def load_config(self):
            # Only forward options gunicorn recognises and that are set.
            config = {key: value for key, value in self.options.items()
                      if key in self.cfg.settings and value is not None}
            for key, value in config.items():
                self.cfg.set(key.lower(), value)

        def load(self):
            return self.application


def inference(audio_file, model=model, model_en=model_en):
    """Transcribe an audio file with Whisper.

    The language is detected with the multilingual ``model``; when it is
    English, ``model_en`` performs the transcription instead.

    Args:
        audio_file: Path to the audio file on disk.
        model: Multilingual Whisper model, used for language detection and
            non-English transcription.
        model_en: English-only Whisper model.

    Returns:
        ``(segmented_text, text)``: ``segmented_text`` has one
        ``"<start>s: <text>"`` line per segment; ``text`` is the full
        transcription.
    """
    # Detection runs on a single padded/trimmed window, not the whole file.
    audio = whisper.pad_or_trim(whisper.load_audio(audio_file))
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    _, probs = model.detect_language(mel)
    lang = max(probs, key=probs.get)

    _model = model_en if lang == "en" else model
    # fp16=False forces FP32 decoding.
    result = _model.transcribe(audio_file, fp16=False, language=lang)

    segmented_text = "\n".join(
        f'{segment["start"]:.1f}s: {segment["text"]}'
        for segment in result["segments"]
    )
    return segmented_text, result['text']


def allowed_file(filename):
    """Return True if *filename*'s extension is in ALLOWED_EXTENSIONS (case-insensitive)."""
    return '.' in filename and \
        filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


def _transcribe_upload(model, model_en):
    """Validate the request's ``file`` upload, transcribe it, and build the response.

    Shared by both routes so they validate and respond identically. The
    uploaded file is deleted once transcription finishes.
    """
    if 'file' not in request.files:
        return "no file sent!"
    uploaded_file = request.files['file']
    # Covers both an empty and a missing filename.
    if not uploaded_file.filename:
        return "no file sent!"
    if not allowed_file(uploaded_file.filename):
        return "file type not allowed!", 400

    # Random prefix so concurrent uploads with the same name don't overwrite
    # each other, and the saved name is never empty after sanitising.
    filename = f"{uuid.uuid4().hex}_{secure_filename(uploaded_file.filename)}"
    path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    uploaded_file.save(path)
    try:
        timed_transcription, transcription = inference(path, model, model_en)
    finally:
        # The upload is only needed for this request; don't let disk fill up.
        try:
            os.remove(path)
        except OSError:
            logging.exception("could not remove upload %s", path)
    return jsonify({"results": timed_transcription, "text": transcription})


@app.route("/", methods=['GET', 'POST'])
def index():
    """GET: placeholder text. POST: transcribe with the larger model pair."""
    if request.method == 'POST':
        return _transcribe_upload(model, model_en)
    return "nothing yet to see here"


@app.route("/small", methods=['POST'])
def small():
    """Transcribe the uploaded file with the smaller model pair."""
    return _transcribe_upload(model_small, model_small_en)


if __name__ == "__main__":
    options = {
        'bind': f'0.0.0.0:{PORT}',
        # 'workers': 4,
        'timeout': 600,
    }
    if gunicorn_present:
        try:
            StandaloneApplication(app, options).run()
        except Exception:
            # Not a bare except: gunicorn exits via SystemExit on shutdown,
            # which must not trigger the fallback server.
            logging.exception("Error starting server, backing off to debug mode")
            app.run(host="0.0.0.0", port=PORT)
    else:
        app.run(host="0.0.0.0", port=PORT)