"""Flask service that transcribes a remote audio file with Whisper (JAX).

Endpoints:
    GET  /     -- liveness check, returns {"output": "okay"}.
    POST /run  -- form field ``audio_url``; downloads the audio, transcribes it
                  and returns the text plus the elapsed time in seconds.
"""

import os
import subprocess
import sys
import tempfile
import time

# whisper-jax is installed from its GitHub repository at start-up, before it is
# imported below. ``sys.executable -m pip`` targets the interpreter running this
# file. As with the previous ``os.system`` call, a failed install is not fatal
# here; the import of ``whisper_jax`` below would fail instead.
subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "git+https://github.com/sanchit-gandhi/whisper-jax.git",
    ],
    check=False,
)

import jax.numpy as jnp  # noqa: E402
import requests  # noqa: E402
from flask import Flask, jsonify, request  # noqa: E402
from whisper_jax import FlaxWhisperPipline  # noqa: E402  (sic: library's spelling)

# Timeout (seconds) passed to requests.get for the audio download.
DOWNLOAD_TIMEOUT_SEC = 60

# Built once at import time and shared by every request.
pipe = FlaxWhisperPipline("openai/whisper-small", dtype=jnp.bfloat16, batch_size=16)

app = Flask(__name__)
# NOTE(review): "TIMEOUT" is not a built-in Flask config key and nothing in this
# file reads it -- presumably meant as a request timeout; confirm who consumes it.
app.config["TIMEOUT"] = 60 * 10  # 10 mins


@app.route("/")
def indexApi():
    """Liveness check: always returns {"output": "okay"}."""
    return jsonify({"output": "okay"})


def _error(message):
    """Build the HTTP 400 JSON error response used by ``/run``."""
    return jsonify({"result": message}), 400


@app.route("/run", methods=["POST"])
def runApi():
    """Download the audio at form field ``audio_url`` and transcribe it.

    Returns:
        JSON ``{"audio_url", "result", "exec_time_sec"}`` on success, where
        ``result`` is the transcription text. On failure (missing URL, download
        error, or non-OK HTTP status) returns ``{"result": <message>}`` with
        status 400.
    """
    start_time = time.perf_counter()

    audio_url = request.form.get("audio_url")
    if not audio_url:
        return _error("Missing form field: audio_url")

    # NOTE(review): the server fetches any URL the client supplies (SSRF risk);
    # consider restricting schemes/hosts if this is exposed publicly.
    try:
        response = requests.get(audio_url, timeout=DOWNLOAD_TIMEOUT_SEC)
    except requests.RequestException as err:
        return _error(f"Unable to download file: {err}")

    if response.status_code != requests.codes.ok:
        return _error(f"Unable to save file, status code: {response.status_code}")

    # A per-request temp file instead of a shared "audio.mp3": concurrent
    # requests would otherwise overwrite each other's audio before transcription.
    with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
        tmp.write(response.content)
        audio_path = tmp.name
    try:
        prediction = pipe(audio_path, task="transcribe")["text"]
    finally:
        os.remove(audio_path)

    total_time = time.perf_counter() - start_time
    return jsonify({
        "audio_url": audio_url,
        "result": prediction,
        "exec_time_sec": total_time,
    })


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)