defmodule MedicalTranscription.Audio.TranscriptionFilter do
  @moduledoc """
  A Membrane filter to transcribe audio with the Whisper model.

  Based on
  https://github.com/membraneframework/membrane_demo/blob/master/speech_to_text/speech_to_text.livemd
  """

  use Membrane.Filter

  alias Membrane.RawAudio

  require Membrane.Logger

  @vad_chunk_duration Membrane.Time.milliseconds(500)

  def_input_pad(:input,
    accepted_format: %RawAudio{sample_format: :f32le, channels: 1, sample_rate: 16_000}
  )

  def_output_pad(:output, accepted_format: Membrane.RemoteStream)

  def_options(
    chunk_duration: [
      spec: Membrane.Time.t(),
      default: Membrane.Time.seconds(2),
      default_inspector: &Membrane.Time.pretty_duration/1,
      description: """
      The duration of the chunks feeding the model. The longer the chunks,
      the better the transcription accuracy, but the higher the latency.
      """
    ],
    vad_threshold: [
      spec: float,
      default: 0.03,
      description: """
      Volume threshold below which the input is considered to be silence.
      Used for optimizing alignment of the chunks provided to the model and
      for filtering out silence to prevent hallucinations.
      """
    ]
  )

  @impl true
  def handle_setup(_ctx, options) do
    Membrane.Logger.info("Whisper model ready")

    state =
      Map.merge(options, %{
        speech: <<>>,
        queue: <<>>,
        chunk_size: nil,
        vad_chunk_size: nil
      })

    {[], state}
  end

  @impl true
  def handle_stream_format(:input, stream_format, _ctx, state) do
    # The stream format carries the sample rate and format, so only now can we
    # translate the configured durations into byte counts.
    state = %{
      state
      | chunk_size: RawAudio.time_to_bytes(state.chunk_duration, stream_format),
        vad_chunk_size: RawAudio.time_to_bytes(@vad_chunk_duration, stream_format)
    }

    {[stream_format: {:output, %Membrane.RemoteStream{}}], state}
  end

  @impl true
  def handle_buffer(:input, buffer, _ctx, state) do
    input = state.queue <> buffer.payload

    if byte_size(input) > state.vad_chunk_size do
      process_data(input, %{state | queue: <<>>})
    else
      {[], %{state | queue: input}}
    end
  end

  defp process_data(data, state) do
    # Here we filter out the silence at the beginning of each chunk. This way
    # we can fit as much speech as possible into a single chunk and
    # potentially drop whole silent chunks, which cause model hallucinations.
    # If, after removing the silence, the chunk is not empty but still too
    # small to process, we store it in the state and prepend it to the
    # subsequent chunk.
    speech =
      if state.speech == <<>> do
        MedicalTranscription.Audio.Utilities.filter_silence(data, state)
      else
        state.speech <> data
      end

    if byte_size(speech) < state.chunk_size do
      {[], %{state | speech: speech}}
    else
      model_input = Nx.from_binary(speech, :f32)

      transcription =
        MedicalTranscription.TranscriptionServing
        |> Nx.Serving.batched_run(model_input)
        |> Enum.map_join(& &1.text)

      buffer = %Membrane.Buffer{payload: transcription}
      {[buffer: {:output, buffer}], %{state | speech: <<>>}}
    end
  end
end
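# `process_data/2` above delegates VAD to
# `MedicalTranscription.Audio.Utilities.filter_silence/2`, defined elsewhere in
# the app. A minimal sketch of what it could look like, adapted from the
# Membrane demo referenced in the moduledoc: split the audio into
# `vad_chunk_size`-byte chunks, estimate each chunk's volume as the standard
# deviation of its samples, and drop leading chunks quieter than
# `state.vad_threshold`. The module below is illustrative, not the app's
# actual implementation.
defmodule MedicalTranscription.Audio.UtilitiesSketch do
  @moduledoc false

  def filter_silence(samples, state) do
    samples
    |> generate_chunks(state.vad_chunk_size)
    |> Enum.drop_while(&(calc_volume(&1) < state.vad_threshold))
    |> Enum.join()
  end

  # Splits the binary into `chunk_size`-byte chunks; the final chunk may be
  # smaller than `chunk_size`.
  defp generate_chunks(samples, chunk_size) when byte_size(samples) >= 2 * chunk_size do
    <<chunk::binary-size(chunk_size), rest::binary>> = samples
    [chunk | generate_chunks(rest, chunk_size)]
  end

  defp generate_chunks(samples, _chunk_size), do: [samples]

  # Volume estimate: standard deviation of the f32 samples in the chunk.
  defp calc_volume(chunk) do
    chunk |> Nx.from_binary(:f32) |> Nx.standard_deviation() |> Nx.to_number()
  end
end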
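# `process_data/2` enumerates the result of `Nx.Serving.batched_run/2` and
# joins each chunk's `:text`, which implies a Bumblebee Whisper serving
# configured with `stream: true` and registered under
# `MedicalTranscription.TranscriptionServing`. One plausible setup is sketched
# below; the model name, chunking options, and the `build/0` helper are
# assumptions, not the app's actual configuration.
defmodule MedicalTranscription.TranscriptionServingSketch do
  @moduledoc false

  # Returns a serving to be started in the supervision tree, e.g.:
  #   {Nx.Serving,
  #    serving: MedicalTranscription.TranscriptionServingSketch.build(),
  #    name: MedicalTranscription.TranscriptionServing}
  def build do
    {:ok, model_info} = Bumblebee.load_model({:hf, "openai/whisper-tiny"})
    {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/whisper-tiny"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/whisper-tiny"})
    {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai/whisper-tiny"})

    Bumblebee.Audio.speech_to_text_whisper(
      model_info,
      featurizer,
      tokenizer,
      generation_config,
      # Stream chunks back as they are transcribed, each with a :text field.
      stream: true,
      defn_options: [compiler: EXLA]
    )
  end
end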
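# A sketch of how the filter could be wired into a pipeline, following the
# Membrane demo referenced in the moduledoc: capture microphone audio,
# resample it to the 16 kHz mono f32le format the input pad accepts,
# transcribe, and print the result. The element choices are assumptions and
# would require the membrane_portaudio_plugin and
# membrane_ffmpeg_swresample_plugin packages.
defmodule MedicalTranscription.Audio.ExamplePipeline do
  @moduledoc false

  use Membrane.Pipeline

  alias Membrane.RawAudio

  @impl true
  def handle_init(_ctx, _opts) do
    spec =
      child(:source, Membrane.PortAudio.Source)
      |> child(:converter, %Membrane.FFmpeg.SWResample.Converter{
        output_stream_format: %RawAudio{sample_format: :f32le, sample_rate: 16_000, channels: 1}
      })
      |> child(:transcriber, MedicalTranscription.Audio.TranscriptionFilter)
      |> child(:sink, %Membrane.Debug.Sink{handle_buffer: &IO.puts(&1.payload)})

    {[spec: spec], %{}}
  end
end

# Start it with, e.g.:
#   Membrane.Pipeline.start_link(MedicalTranscription.Audio.ExamplePipeline)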