cyberspyde committed on
Commit c858f8e
1 Parent(s): 06eaec7
Files changed (1)
  1. main.py +41 -11
main.py CHANGED
@@ -1,27 +1,57 @@
 from flask import Flask, request, jsonify
 from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
-from transformers import pipeline
 import numpy as np
-import json
+import torch
+
 app = Flask(__name__)
 model = AutoModelForSpeechSeq2Seq.from_pretrained("GitNazarov/whisper-small-pt-3-uz")
 processor = AutoProcessor.from_pretrained("GitNazarov/whisper-small-pt-3-uz")
 
+
+USE_ONNX = False  # change this to True if you want to test onnx model
+silero_vad_path = 'snakers4/silero-vad'
+vad_model, vad_utils = torch.hub.load(silero_vad_path,
+                                      model='silero_vad',
+                                      force_reload=True,
+                                      onnx=USE_ONNX)
+
+(get_speech_timestamps,
+ save_audio,
+ read_audio,
+ VADIterator,
+ collect_chunks) = vad_utils
+STT_SAMPLE_RATE = 16000
+
+
+def int2float(sound):
+    abs_max = np.abs(sound).max()
+    sound = sound.astype('float32')
+    if abs_max > 0:
+        sound *= 1/32768
+    sound = sound.squeeze()  # depends on the use case
+    return sound
+
 @app.route('/', methods=['GET'])
 def index():
     return jsonify({"message": "Welcome to whisper uz!"})
 
 @app.route('/transcribe', methods=['POST'])
 def transcribe():
-    data = request.json['data']
-    data = json.loads(data)
-    tensor_data = np.array(data)
-    inputs = processor(tensor_data, return_tensors="pt", sampling_rate=16000, max_new_tokens=100)
-    input_features = inputs.input_features
-    generated_ids = model.generate(inputs=input_features)
-
-    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
-    transcription = ''.join(transcription)
+    data_frames = request.data
+    audio_data = np.frombuffer(data_frames, dtype=np.int16)
+    audio_float = int2float(audio_data)
+    final_data = torch.from_numpy(audio_float)
+    sp_timestamps = get_speech_timestamps(final_data, vad_model, sampling_rate=STT_SAMPLE_RATE)
+    try:
+        final_audio_data = collect_chunks(sp_timestamps, final_data)
+        inputs = processor(final_audio_data, return_tensors="pt", sampling_rate=16000, max_new_tokens=100)
+        input_features = inputs.input_features
+        generated_ids = model.generate(inputs=input_features)
+
+        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
+        transcription = ''.join(transcription)
+    except Exception as e:
+        transcription = str(e)
     return str(transcription), {'Content-Type': 'application/json'}
 
 if __name__ == '__main__':
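
What the commit changes: /transcribe no longer accepts a JSON-encoded float array under a "data" key. It now reads raw 16-bit PCM directly from the request body, converts it to float32 with int2float(), trims non-speech with Silero VAD (loaded via torch.hub.load('snakers4/silero-vad', ...)), and only then runs the Whisper model; if VAD finds no speech or anything else fails, the except branch returns the exception text as the response. One carry-over worth noting: max_new_tokens is a generation-time parameter, so it belongs with model.generate() rather than the processor(...) call.

A minimal client sketch for the new contract, assuming the server runs at http://localhost:5000 and the input is a 16 kHz mono WAV file. The URL, filename, and use of the soundfile library are illustrative assumptions, not part of the commit:

import requests
import soundfile as sf

# Read the file as int16 PCM; the server expects raw little-endian
# 16-bit frames at STT_SAMPLE_RATE (16000 Hz), with no container header.
audio, sr = sf.read("sample_uz.wav", dtype="int16")
assert sr == 16000, "server does not resample; send 16 kHz audio"

resp = requests.post(
    "http://localhost:5000/transcribe",
    data=audio.tobytes(),  # decoded with np.frombuffer(..., dtype=np.int16) server-side
    headers={"Content-Type": "application/octet-stream"},
)
print(resp.text)  # plain transcription string (served with a JSON content-type header)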