Spaces:
Build error
Build error
from asyncio.log import logger | |
import logging | |
import os | |
import whisper | |
from flask import Flask, jsonify, request | |
from werkzeug.utils import secure_filename | |
# gunicorn is an optional dependency: try to import it and record the outcome
# in `gunicorn_present` so the rest of the module can tell whether the
# gunicorn-based server wrapper can be defined.
try:
    import gunicorn.app.base
    gunicorn_present = True
except ImportError:
    # Logs the ImportError traceback; the module keeps loading without gunicorn.
    logging.exception("gunicorn not installed")
    gunicorn_present = False
# NOTE(review): indentation was lost in this view, so it is unclear whether
# this assignment belongs to the `except` branch above (a harmless duplicate)
# or is a module-level override. As a module-level statement it forces
# gunicorn off regardless of the import result - confirm which was intended.
gunicorn_present = False
# Whisper checkpoint name used for the default model pair (model / model_en).
STARTING_SIZE = 'small'
# Directory (relative to the working directory) where uploads are written
# before they are transcribed.
UPLOAD_FOLDER = 'uploads'
# Lower-case file extensions accepted by allowed_file().
ALLOWED_EXTENSIONS = {'ogg', 'mp3', 'mp4', 'wav',
                      'flac', 'm4a', 'aac', 'wma', 'webm', 'opus'}
normal_size = STARTING_SIZE
# Checkpoint name for the alternative pair (model_small / model_small_en).
# Note the naming: the "small" pair loads the 'base' checkpoint, while the
# default pair loads the 'small' checkpoint.
small_size = 'base'
# Port shared by the gunicorn bind address and the Flask fallback server.
PORT = 7860
app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# All four models are loaded eagerly at import time. Each size is loaded
# twice: once under its plain name and once with the ".en" suffix; inference()
# uses the ".en" variant when the detected language is English.
model = whisper.load_model(normal_size)
model_en = whisper.load_model(f"{normal_size}.en")
model_small = whisper.load_model(small_size)
model_small_en = whisper.load_model(f"{small_size}.en")
if gunicorn_present:
    class StandaloneApplication(gunicorn.app.base.BaseApplication):
        """Gunicorn application wrapper that serves an in-process WSGI app.

        The class only exists when gunicorn could be imported, because its
        base class comes from ``gunicorn.app.base``.
        """

        def __init__(self, app, options=None):
            """Remember the WSGI app and the option mapping, then init the base.

            Args:
                app: the WSGI application object returned by ``load()``.
                options: mapping of gunicorn setting names to values; a
                    falsy value is treated as an empty mapping.
            """
            # Both attributes are assigned before the base initialiser runs,
            # matching the order load_config()/load() rely on.
            self.options = options if options else {}
            self.application = app
            super().__init__()

        def load_config(self):
            """Apply every recognised, non-None option to ``self.cfg``."""
            known_settings = self.cfg.settings
            for name, setting in self.options.items():
                # Skip names gunicorn does not know and options left as None.
                if name not in known_settings or setting is None:
                    continue
                self.cfg.set(name.lower(), setting)

        def load(self):
            """Return the wrapped WSGI application."""
            return self.application
def inference(audio_file, model=model, model_en=model_en):
    """Detect the spoken language of an audio file and transcribe it.

    Args:
        audio_file: path handed to ``whisper.load_audio`` and to
            ``transcribe``.
        model: model used for language detection and for transcribing
            anything that is not detected as English.
        model_en: model used for transcription when the detected language
            is ``"en"``.

    Returns:
        A ``(segmented_text, text)`` tuple. ``segmented_text`` has one line
        per transcription segment in the form ``"<start>s: <text>"`` with the
        start time shown to one decimal place; ``text`` is the ``'text'``
        field of the transcription result.
    """
    # Language detection runs on a padded/trimmed copy of the audio, while
    # the transcription below is given the original file path.
    samples = whisper.pad_or_trim(whisper.load_audio(audio_file))
    mel = whisper.log_mel_spectrogram(samples).to(model.device)
    _, probs = model.detect_language(mel)
    lang = max(probs, key=probs.get)

    transcriber = model_en if lang == "en" else model
    result = transcriber.transcribe(audio_file, fp16=False, language=lang)

    text = result['text']
    segmented_text = "\n".join(
        f'{part["start"]:.1f}s: {part["text"]}' for part in result["segments"]
    )
    return segmented_text, text
def allowed_file(filename):
    """Return True when *filename* has an extension in ALLOWED_EXTENSIONS.

    The extension is the text after the last ``.`` and is compared
    case-insensitively. A name without any dot is rejected.
    """
    _stem, dot, extension = filename.rpartition('.')
    if not dot:
        # rpartition found no '.', so there is no extension to check.
        return False
    return extension.lower() in ALLOWED_EXTENSIONS
# NOTE(review): no route decorator is visible on this function in this view;
# confirm it is registered with `app` (the body handles POST and other methods).
def index():
    """Transcribe an uploaded audio file with the default model pair.

    On POST the request must contain a ``file`` part with an allowed
    extension. The upload is saved under ``app.config['UPLOAD_FOLDER']`` and
    passed to ``inference()``.

    Returns:
        A JSON response ``{"results": <timed text>, "text": <plain text>}``
        after a successful transcription; the string ``"no file sent!"`` when
        the file part is missing or has an empty filename; otherwise (non-POST
        request or disallowed extension) ``"nothing yet to see here"``.
    """
    if request.method == 'POST':
        # check if the post request has the file part
        if 'file' not in request.files:
            return "no file sent!"
        uploaded_file = request.files['file']
        if uploaded_file.filename == '':
            return "no file sent!"
        if uploaded_file and allowed_file(uploaded_file.filename):
            upload_dir = app.config['UPLOAD_FOLDER']
            # Saving fails if the directory is missing, so create it on demand.
            os.makedirs(upload_dir, exist_ok=True)
            # Build the destination once and reuse it for saving and inference.
            saved_path = os.path.join(
                upload_dir, secure_filename(uploaded_file.filename))
            uploaded_file.save(saved_path)
            timed_transcription, transcription = inference(saved_path)
            return jsonify({"results": timed_transcription, "text": transcription})
    return "nothing yet to see here"
# NOTE(review): no route decorator is visible on this function in this view;
# confirm it is registered with `app`.
def small():
    """Transcribe an uploaded audio file with the ``model_small`` pair.

    The request must contain a ``file`` part with an allowed extension. The
    upload is saved under ``app.config['UPLOAD_FOLDER']`` and passed to
    ``inference()`` together with ``model_small`` / ``model_small_en``.

    Returns:
        A JSON response ``{"results": <timed text>, "text": <plain text>}``
        after a successful transcription; the string ``"no file sent!"`` when
        the file part is missing or has an empty filename; a 400 response
        when the file extension is not allowed.
    """
    if 'file' not in request.files:
        return "no file sent!"
    uploaded_file = request.files['file']
    if uploaded_file.filename == '':
        return "no file sent!"
    if not (uploaded_file and allowed_file(uploaded_file.filename)):
        # Previously this path fell off the end and returned None, which
        # Flask rejects as an invalid view return value (HTTP 500).
        return "file type not allowed!", 400

    upload_dir = app.config['UPLOAD_FOLDER']
    # Saving fails if the directory is missing, so create it on demand.
    os.makedirs(upload_dir, exist_ok=True)
    # Build the destination once and reuse it for saving and inference.
    saved_path = os.path.join(
        upload_dir, secure_filename(uploaded_file.filename))
    uploaded_file.save(saved_path)
    timed_transcription, transcription = inference(
        saved_path, model_small, model_small_en)
    return jsonify({"results": timed_transcription, "text": transcription})
if __name__ == "__main__":
    # Script entry point: serve `app` with gunicorn when it is available,
    # otherwise (or if gunicorn fails to start) with Flask's built-in server.
    options = {
        'bind': f'0.0.0.0:{PORT}',
        # 'workers': 4,
        'timeout': 600,
    }
    if gunicorn_present:
        try:
            StandaloneApplication(app, options).run()
        except Exception:
            # Only real errors trigger the fallback. A bare `except:` would
            # also trap SystemExit / KeyboardInterrupt raised while the
            # gunicorn server shuts down and then start a second server.
            logging.exception("Error starting server, backing off to debug mode")
            app.run(host="0.0.0.0", port=PORT)
    else:
        # StandaloneApplication is not defined without gunicorn; check the
        # flag explicitly rather than relying on a NameError being caught.
        logging.warning("gunicorn not available, backing off to debug mode")
        app.run(host="0.0.0.0", port=PORT)