huedaya commited on
Commit
03d9e79
1 Parent(s): 2c0de77
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -7,6 +7,8 @@ import requests
7
  import time
8
  # from transformers import pipeline
9
  from whisper_jax import FlaxWhisperPipline
 
 
10
 
11
 
12
  # model = whisper.load_model("small")
@@ -16,7 +18,7 @@ from whisper_jax import FlaxWhisperPipline
16
  # chunk_length_s=15,
17
  # device=model.device,
18
  # )
19
- pipe = FlaxWhisperPipline("openai/whisper-small")
20
 
21
  app = Flask(__name__)
22
  app.config['TIMEOUT'] = 60 * 10 # 10 mins
@@ -57,7 +59,7 @@ def runApi():
57
  # test 2
58
  # ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
59
  # sample = ds[0]["audio"]
60
- prediction = pipe(audio)["text"]
61
 
62
 
63
  end_time = time.time()
 
7
  import time
8
  # from transformers import pipeline
9
  from whisper_jax import FlaxWhisperPipline
10
+ import jax.numpy as jnp
11
+
12
 
13
 
14
  # model = whisper.load_model("small")
 
18
  # chunk_length_s=15,
19
  # device=model.device,
20
  # )
21
+ pipe = FlaxWhisperPipline("openai/whisper-small", dtype=jnp.bfloat16, batch_size=16)
22
 
23
  app = Flask(__name__)
24
  app.config['TIMEOUT'] = 60 * 10 # 10 mins
 
59
  # test 2
60
  # ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
61
  # sample = ds[0]["audio"]
62
+ prediction = pipe(audio, task="transcribe")["text"]
63
 
64
 
65
  end_time = time.time()