Is whisper.cpp much slower than whisper-jax?

#12
by zhimakaimenhf - opened

I built whisper.cpp with the following command:

sudo WHISPER_CUBLAS=1 WHISPER_CLBLAST=1 make -j8

Then I ran:

./main -m models/ggml-base.bin -f samples/jfk.wav

I got the following output:

whisper_init_from_file_no_state: loading model from 'models/ggml-base.bin'
whisper_model_load: loading model
whisper_model_load: n_vocab = 51865
whisper_model_load: n_audio_ctx = 1500
whisper_model_load: n_audio_state = 512
whisper_model_load: n_audio_head = 8
whisper_model_load: n_audio_layer = 6
whisper_model_load: n_text_ctx = 448
whisper_model_load: n_text_state = 512
whisper_model_load: n_text_head = 8
whisper_model_load: n_text_layer = 6
whisper_model_load: n_mels = 80
whisper_model_load: ftype = 1
whisper_model_load: qntvr = 0
whisper_model_load: type = 2
whisper_model_load: adding 1608 extra tokens
whisper_model_load: model ctx = 140.66 MB
ggml_init_cublas: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 3060, compute capability 8.6
whisper_model_load: model size = 140.54 MB
whisper_init_state: kv self size = 5.25 MB
whisper_init_state: kv cross size = 17.58 MB
whisper_init_state: compute buffer (conv) = 14.10 MB
whisper_init_state: compute buffer (encode) = 81.85 MB
whisper_init_state: compute buffer (cross) = 4.40 MB
whisper_init_state: compute buffer (decode) = 24.61 MB

system_info: n_threads = 4 / 12 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | METAL = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | COREML = 0 | OPENVINO = 0 |

main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...

[00:00:00.000 --> 00:00:07.600] And so my fellow Americans, ask not what your country can do for you,
[00:00:07.600 --> 00:00:10.600] ask what you can do for your country.

whisper_print_timings: load time = 856.69 ms
whisper_print_timings: fallbacks = 0 p / 0 h
whisper_print_timings: mel time = 9.35 ms
whisper_print_timings: sample time = 9.88 ms / 29 runs ( 0.34 ms per run)
whisper_print_timings: encode time = 361.23 ms / 1 runs ( 361.23 ms per run)
whisper_print_timings: decode time = 138.41 ms / 28 runs ( 4.94 ms per run)
whisper_print_timings: prompt time = 6.21 ms / 1 runs ( 6.21 ms per run)
whisper_print_timings: total time = 1426.51 ms

With the Python project whisper-jax, the same model on the same WAV file takes only 0.137 s.
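
To make the comparison concrete, here is a small sketch that parses the whisper_print_timings lines from the log above and separates model loading from the actual transcription work. It assumes (the post does not say) that the 0.137 s whisper-jax figure measures transcription only, excluding model loading and JIT compilation; under that assumption the fairer whisper.cpp number is the total minus the load time. The helper parse_timings and the constants are hypothetical names chosen for illustration.

```python
import re

# Timing lines copied verbatim from the whisper.cpp output above.
LOG = """\
whisper_print_timings: load time = 856.69 ms
whisper_print_timings: fallbacks = 0 p / 0 h
whisper_print_timings: mel time = 9.35 ms
whisper_print_timings: sample time = 9.88 ms / 29 runs ( 0.34 ms per run)
whisper_print_timings: encode time = 361.23 ms / 1 runs ( 361.23 ms per run)
whisper_print_timings: decode time = 138.41 ms / 28 runs ( 4.94 ms per run)
whisper_print_timings: prompt time = 6.21 ms / 1 runs ( 6.21 ms per run)
whisper_print_timings: total time = 1426.51 ms
"""

# Captures the stage name and the first millisecond value on "<stage> time = X ms" lines.
PATTERN = re.compile(r"(\w+) time = ([\d.]+) ms")


def parse_timings(log: str) -> dict[str, float]:
    """Return {stage: milliseconds} for every '<stage> time = X ms' line (hypothetical helper)."""
    timings = {}
    for line in log.splitlines():
        m = PATTERN.search(line)
        if m:
            timings[m.group(1)] = float(m.group(2))
    return timings


t = parse_timings(LOG)

# Per-stage compute reported by whisper.cpp (everything except model loading).
compute_ms = sum(t[s] for s in ("mel", "sample", "encode", "decode", "prompt"))
# Wall-clock total with the one-off model load removed.
without_load_ms = t["total"] - t["load"]
# whisper-jax figure reported in the post (assumed to exclude loading).
jax_ms = 137.0

print(f"compute stages:       {compute_ms:.2f} ms")                # 525.08 ms
print(f"total minus load:     {without_load_ms:.2f} ms")           # 569.82 ms
print(f"ratio incl. load:     {t['total'] / jax_ms:.1f}x")         # 10.4x
print(f"ratio excl. load:     {without_load_ms / jax_ms:.1f}x")    # 4.2x
```

Under that assumption, the gap shrinks from roughly 10x to roughly 4x, and the encoder (361.23 ms in a single run) is the largest remaining cost on the whisper.cpp side.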

whisper-jax runs on JAX, but with only one RTX 3060 it can't gain much from pmap (parallel execution across multiple devices).

Any ideas?
