aliosha commited on
Commit
511ffb7
1 Parent(s): 928afe5

addin small entrypoint

Browse files
Files changed (1) hide show
  1. app.py +24 -5
app.py CHANGED
@@ -1,22 +1,25 @@
1
  import os
2
  import whisper
3
- from flask import Flask, jsonify, render_template, request, send_file
4
  from werkzeug.utils import secure_filename
5
 
6
  STARTING_SIZE = 'small'
7
  UPLOAD_FOLDER = 'uploads'
8
  ALLOWED_EXTENSIONS = {'ogg', 'mp3', 'mp4', 'wav',
9
  'flac', 'm4a', 'aac', 'wma', 'webm', 'opus'}
10
- current_size = STARTING_SIZE
 
11
 
12
  app = Flask(__name__)
13
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
14
 
15
- model = whisper.load_model(current_size)
16
- model_en = whisper.load_model(f"{current_size}.en")
 
 
17
 
18
 
19
- def inference(audio_file):
20
  audio = whisper.load_audio(audio_file)
21
  audio = whisper.pad_or_trim(audio)
22
  mel = whisper.log_mel_spectrogram(audio).to(model.device)
@@ -62,5 +65,21 @@ def index():
62
  return "nothing yet to see here"
63
 
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  if __name__ == "__main__":
66
  app.run(host="0.0.0.0", port=7860)
 
1
  import os
2
  import whisper
3
+ from flask import Flask, jsonify, request
4
  from werkzeug.utils import secure_filename
5
 
6
  STARTING_SIZE = 'small'
7
  UPLOAD_FOLDER = 'uploads'
8
  ALLOWED_EXTENSIONS = {'ogg', 'mp3', 'mp4', 'wav',
9
  'flac', 'm4a', 'aac', 'wma', 'webm', 'opus'}
10
+ normal_size = STARTING_SIZE
11
+ small_size = 'base'
12
 
13
  app = Flask(__name__)
14
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
15
 
16
+ model = whisper.load_model(normal_size)
17
+ model_en = whisper.load_model(f"{normal_size}.en")
18
+ model_small = whisper.load_model(small_size)
19
+ model_small_en = whisper.load_model(f"{small_size}.en")
20
 
21
 
22
+ def inference(audio_file, model=model, model_en=model_en):
23
  audio = whisper.load_audio(audio_file)
24
  audio = whisper.pad_or_trim(audio)
25
  mel = whisper.log_mel_spectrogram(audio).to(model.device)
 
65
  return "nothing yet to see here"
66
 
67
 
68
+ @app.route("/small", methods=['POST'])
69
+ def index():
70
+ if 'file' not in request.files:
71
+ return "no file sent!"
72
+ uploaded_file = request.files['file']
73
+ if uploaded_file.filename == '':
74
+ return "no file sent!"
75
+ if uploaded_file and allowed_file(uploaded_file.filename):
76
+ filename = secure_filename(uploaded_file.filename)
77
+ uploaded_file.save(os.path.join(
78
+ app.config['UPLOAD_FOLDER'], filename))
79
+ results = inference(os.path.join(
80
+ app.config['UPLOAD_FOLDER'], filename), model_small, model_small_en)
81
+ return jsonify({"results": results})
82
+
83
+
84
  if __name__ == "__main__":
85
  app.run(host="0.0.0.0", port=7860)