Motivation behind the development of optimised JAX code for OpenAI's Whisper Model

by architsingh

Hello Sanchit, I wanted to express my appreciation for the work you've done on this implementation of OpenAI's Whisper model. It's clear that you've put a lot of effort into creating such an outstanding solution.

I had some questions, if you could take the time to answer them:

  1. What motivated the development of optimized JAX code for OpenAI's Whisper Model, and what are this implementation's key features and benefits in terms of speed, hardware compatibility, and usage?

  2. What specific techniques or optimizations were used to achieve over 70x faster performance with Whisper JAX compared to OpenAI's PyTorch code, and how do they contribute to the overall speed and efficiency of the implementation?

  3. How can the standalone and endpoint modes of Whisper JAX be used in different applications, and what are some real-world examples of using this implementation for audio transcription or other speech-related tasks?

Hey @architsingh! Thanks for your interest in Whisper JAX. Answering your questions below:

  1. We wanted to democratise ASR for the open-source community. OpenAI released their Whisper API a few months ago and claim it is faster than any other ASR API service. We wanted to make Whisper JAX faster still, returning that accolade to an open-source implementation 🏆 This Twitter thread explains the key features at a high level: https://twitter.com/sanchitgandhi99/status/1649046650793648128 There's more detail in the README: https://github.com/sanchit-gandhi/whisper-jax It's compatible with all devices that support JAX: https://github.com/google/jax#installation

  2. Again, the details are in the Twitter thread! In short, the main gains come from JIT-compiling the forward pass, batched processing of audio chunks, and data parallelism across accelerator devices; see the first sketch after this list.

  3. See the section on Creating an Endpoint in the README for details on launching your own demo and pinging the endpoint; a rough sketch of querying an endpoint from Python is given after this list. See https://developer.nvidia.com/blog/exploring-unique-applications-of-automatic-speech-recognition-technology/ for real-world applications.
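On the optimizations in question 2: per the README, the speed-ups come from JIT-compiling the forward pass, transcribing long audio as parallel batches of fixed-length chunks, and running in half precision on accelerator hardware. Here's a minimal sketch of how these surface in the user-facing API; the class and argument names follow the README's usage example, but treat the exact signature as an assumption and check the repo for the current API:

```python
import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline  # class name as spelled in the repo

# Half-precision weights cut memory traffic on GPU/TPU (dtype argument per the README)
pipeline = FlaxWhisperPipline(
    "openai/whisper-large-v2",
    dtype=jnp.bfloat16,  # half precision
    batch_size=16,       # long audio is chunked and the chunks transcribed in parallel
)

# First call traces and JIT-compiles the forward pass: a slow, one-off cost
text = pipeline("audio.mp3")

# Subsequent calls reuse the cached compiled function: this is where the speed-up shows
text = pipeline("audio.mp3")
```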
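And for the endpoint mode in question 3: a running demo exposes a Gradio API that you can ping programmatically. A rough sketch using the gradio_client package follows; the Space URL, api_name, and the (text, runtime) return shape are assumptions based on the public demo, so run client.view_api() to confirm the interface of your own endpoint:

```python
from gradio_client import Client

# URL of a running Whisper JAX demo (substitute your own endpoint here)
API_URL = "https://sanchit-gandhi-whisper-jax.hf.space/"

client = Client(API_URL)

def transcribe(audio_path: str, task: str = "transcribe", return_timestamps: bool = False) -> str:
    """Send a local audio file to the endpoint and return the transcription."""
    # Positional inputs and api_name mirror the demo's Gradio interface
    # (assumed: inspect client.view_api() to confirm for your endpoint)
    text, runtime = client.predict(
        audio_path,
        task,
        return_timestamps,
        api_name="/predict_1",
    )
    return text

print(transcribe("audio.mp3"))
```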
