# transcribe_api / app.py
# Last commit by aliosha: "going back to working version" (5cb25c7)
from asyncio.log import logger
import logging
import os
import whisper
from flask import Flask, jsonify, request
from werkzeug.utils import secure_filename
try:
import gunicorn.app.base
gunicorn_present = True
except ImportError:
logging.exception("gunicorn not installed")
gunicorn_present = False
gunicorn_present = False
# Whisper model size behind the default "/" endpoint (via normal_size below).
STARTING_SIZE = 'small'
# Directory (relative to the working directory) where uploads are saved
# before being handed to inference().
UPLOAD_FOLDER = 'uploads'
# Lower-case file extensions accepted by allowed_file().
ALLOWED_EXTENSIONS = {'ogg', 'mp3', 'mp4', 'wav',
                      'flac', 'm4a', 'aac', 'wma', 'webm', 'opus'}
# Size used by the "/" endpoint.
normal_size = STARTING_SIZE
# Size used by the "/small" endpoint.
small_size = 'base'
# Port used both for the gunicorn bind address and the Flask fallback server.
PORT = 7860
app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# All four models are loaded at import time. Each size gets a multilingual
# model, used for language detection and non-English audio, and an
# English-only ".en" variant, which inference() picks when the detected
# language is "en".
model = whisper.load_model(normal_size)
model_en = whisper.load_model(f"{normal_size}.en")
model_small = whisper.load_model(small_size)
model_small_en = whisper.load_model(f"{small_size}.en")
if gunicorn_present:
    class StandaloneApplication(gunicorn.app.base.BaseApplication):
        """Serve a WSGI application through gunicorn from inside Python.

        ``options`` maps gunicorn setting names to values. Entries whose name
        is not a known gunicorn setting, or whose value is None, are skipped.
        """

        def __init__(self, app, options=None):
            # Both attributes are assigned before the base constructor runs;
            # presumably gunicorn's BaseApplication.__init__ calls
            # load_config(), which reads self.options (see gunicorn docs).
            self.application = app
            self.options = {} if not options else options
            super().__init__()

        def load_config(self):
            """Copy every recognised, non-None option into gunicorn's config."""
            for name, setting in self.options.items():
                if setting is None or name not in self.cfg.settings:
                    continue
                self.cfg.set(name.lower(), setting)

        def load(self):
            """Return the WSGI application gunicorn should serve."""
            return self.application
def inference(audio_file, model=model, model_en=model_en):
    """Detect the spoken language of ``audio_file`` and transcribe it.

    Language detection runs with ``model`` on a fixed-length clip made by
    ``whisper.pad_or_trim``. If the detected language is English, the full
    file is transcribed with ``model_en``; otherwise ``model`` is used.

    Returns a ``(segmented_text, text)`` tuple:
    ``segmented_text`` has one ``"<start>s: <segment text>"`` line per
    segment, with the start time to one decimal place. ``text`` is the
    complete transcription string.
    """
    clip = whisper.pad_or_trim(whisper.load_audio(audio_file))
    spectrogram = whisper.log_mel_spectrogram(clip).to(model.device)
    _, language_probs = model.detect_language(spectrogram)
    language = max(language_probs, key=language_probs.get)

    # English audio goes to the English-only model; everything else stays
    # on the multilingual one.
    chosen_model = model_en if language == "en" else model
    result = chosen_model.transcribe(audio_file, fp16=False, language=language)

    full_text = result['text']
    timed_lines = "\n".join(
        f'{segment["start"]:.1f}s: {segment["text"]}'
        for segment in result["segments"]
    )
    return timed_lines, full_text
def allowed_file(filename):
    """Return True if ``filename`` ends in an extension from ALLOWED_EXTENSIONS.

    Only the text after the last '.' is checked, case-insensitively.
    Names without any '.' are rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rpartition('.')[2]
    return extension.lower() in ALLOWED_EXTENSIONS
@app.route("/", methods=['GET', 'POST'])
def index():
    """Transcribe an uploaded audio file with the default-size models.

    GET returns a placeholder string. POST expects a multipart form field
    named ``file`` whose name passes ``allowed_file``. On success it returns
    JSON ``{"results": ..., "text": ...}`` built from ``inference``:
    ``results`` has one ``"<start>s: <text>"`` line per segment, and
    ``text`` is the full transcription. Missing or rejected uploads get a
    plain-text message.

    The upload is written under a random, unique name inside
    ``app.config['UPLOAD_FOLDER']`` and removed once transcription finishes.
    """
    import uuid  # local import keeps this handler self-contained

    if request.method != 'POST':
        return "nothing yet to see here"
    # check if the post request has the file part
    uploaded_file = request.files.get('file')
    if uploaded_file is None or not uploaded_file.filename:
        return "no file sent!"
    if not allowed_file(uploaded_file.filename):
        # Report the rejection explicitly rather than reusing the GET text.
        return "file type not allowed!"

    upload_dir = app.config['UPLOAD_FOLDER']
    os.makedirs(upload_dir, exist_ok=True)
    # The random prefix stops concurrent uploads with the same name from
    # overwriting (or deleting) each other. It also keeps the name non-empty
    # when secure_filename() strips every character.
    filename = f"{uuid.uuid4().hex}_{secure_filename(uploaded_file.filename)}"
    path = os.path.join(upload_dir, filename)
    try:
        uploaded_file.save(path)
        timed_transcription, transcription = inference(path)
    finally:
        # The file is only needed while transcribing; don't let uploads pile up.
        if os.path.exists(path):
            os.remove(path)
    return jsonify({"results": timed_transcription, "text": transcription})
@app.route("/small", methods=['POST'])
def small():
    """Transcribe an uploaded audio file with the smaller ("base") models.

    Expects a multipart form field named ``file`` whose name passes
    ``allowed_file``. On success it returns JSON
    ``{"results": ..., "text": ...}`` built from ``inference`` using
    ``model_small`` / ``model_small_en``. Missing or rejected uploads get a
    plain-text message.

    The upload is written under a random, unique name inside
    ``app.config['UPLOAD_FOLDER']`` and removed once transcription finishes.
    """
    import uuid  # local import keeps this handler self-contained

    uploaded_file = request.files.get('file')
    if uploaded_file is None or not uploaded_file.filename:
        return "no file sent!"
    if not allowed_file(uploaded_file.filename):
        # A view must return a response; returning None here makes Flask
        # raise and answer with a 500.
        return "file type not allowed!"

    upload_dir = app.config['UPLOAD_FOLDER']
    os.makedirs(upload_dir, exist_ok=True)
    # The random prefix stops concurrent uploads with the same name from
    # overwriting (or deleting) each other. It also keeps the name non-empty
    # when secure_filename() strips every character.
    filename = f"{uuid.uuid4().hex}_{secure_filename(uploaded_file.filename)}"
    path = os.path.join(upload_dir, filename)
    try:
        uploaded_file.save(path)
        timed_transcription, transcription = inference(
            path, model_small, model_small_en)
    finally:
        # The file is only needed while transcribing; don't let uploads pile up.
        if os.path.exists(path):
            os.remove(path)
    return jsonify({"results": timed_transcription, "text": transcription})
if __name__ == "__main__":
    # Serve with gunicorn when it was importable; otherwise, or if gunicorn
    # fails with an ordinary error, fall back to Flask's built-in server.
    options = {
        'bind': f'0.0.0.0:{PORT}',
        # 'workers': 4,
        'timeout': 600,
    }
    if gunicorn_present:
        try:
            StandaloneApplication(app, options).run()
        except Exception:
            # Catch Exception rather than using a bare `except:` so that
            # SystemExit/KeyboardInterrupt propagate. gunicorn presumably
            # stops (and ends forked workers) via sys.exit(); a bare except
            # would turn that into a restart on the Flask dev server.
            logging.exception("Error starting server, backing off to debug mode")
            app.run(host="0.0.0.0", port=PORT)
    else:
        # StandaloneApplication is only defined when gunicorn_present is
        # true, so don't try to use it here.
        logging.warning("gunicorn unavailable, starting Flask development server")
        app.run(host="0.0.0.0", port=PORT)