import type {
  AutomaticSpeechRecognitionPipeline,
  CausalLMOutputWithPast,
  GPT2Tokenizer,
  LlamaForCausalLM,
  PreTrainedModel,
  StoppingCriteriaList,
} from '@huggingface/transformers'
import type { Device, DType } from '@xsai-transformers/shared/types'
import type { GenerateOptions } from 'kokoro-js'
import type {
  WorkerMessageEventError,
  WorkerMessageEventInfo,
  WorkerMessageEventOutput,
  WorkerMessageEventProgress,
  WorkerMessageEventSetVoiceResponse,
  WorkerMessageEventStatus,
} from '../types/worker'

import {
  AutoModel,
  AutoModelForCausalLM,
  AutoTokenizer,
  InterruptableStoppingCriteria,
  pipeline,
  Tensor,
  TextStreamer,
} from '@huggingface/transformers'
import { isWebGPUSupported } from 'gpuu/webgpu'
import { KokoroTTS, TextSplitterStream } from 'kokoro-js'

import {
  EXIT_THRESHOLD,
  INPUT_SAMPLE_RATE,
  MAX_BUFFER_DURATION,
  MAX_NUM_PREV_BUFFERS,
  MIN_SILENCE_DURATION_SAMPLES,
  MIN_SPEECH_DURATION_SAMPLES,
  SPEECH_PAD_SAMPLES,
  SPEECH_THRESHOLD,
} from '../constants'

interface Message {
  role: 'system' | 'user' | 'assistant'
  content: string
}

type Voices = GenerateOptions['voice']
export type PretrainedConfig = NonNullable<Parameters<typeof AutoModel.from_pretrained>[1]>['config']

// Per-device dtype selection for the Whisper encoder/decoder modules.
const whisperDtypeMap: Record<Device, DType> = {
  webgpu: {
    encoder_model: 'fp32',
    decoder_model_merged: 'fp32',
  },
  wasm: {
    encoder_model: 'fp32',
    decoder_model_merged: 'q8',
  },
}

const model_id = 'onnx-community/Kokoro-82M-v1.0-ONNX'
let voice: Voices | undefined
let silero_vad: PreTrainedModel
let transcriber: AutomaticSpeechRecognitionPipeline
let tts: KokoroTTS

const SYSTEM_MESSAGE: Message = {
  role: 'system',
  content:
    'You\'re a helpful and conversational voice assistant. Keep your responses short, clear, and casual.',
}
let messages: Message[] = [SYSTEM_MESSAGE]
let past_key_values_cache: any = null
let stopping_criteria: InterruptableStoppingCriteria | null = null

// Rolling audio buffer that accumulates incoming samples while speech is detected.
const BUFFER = new Float32Array(MAX_BUFFER_DURATION * INPUT_SAMPLE_RATE)
let bufferPointer = 0

// Inputs required by the Silero VAD model: the sample rate as a scalar tensor
// and the recurrent hidden state carried between chunks.
const sr = new Tensor('int64', [INPUT_SAMPLE_RATE], [])
let state = new Tensor('float32', new Float32Array(2 * 1 * 128), [2, 1, 128])

// Whether we are currently capturing speech / playing back generated audio.
let isRecording = false
let isPlaying = false

let tokenizer: GPT2Tokenizer
let llm: LlamaForCausalLM

// Chunks received immediately before speech started, kept to pad the front of the utterance.
const prevBuffers: Float32Array[] = []

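/**
 * Loads every model used by the worker: Kokoro TTS, Silero VAD, Whisper for
 * speech recognition and SmolLM2 as the chat LLM, then runs a tiny warm-up
 * inference on the transcriber and the LLM before reporting `ready`.
 */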
export async function loadModels() {
  tts = await KokoroTTS.from_pretrained(model_id, {
    dtype: 'fp32',
    device: 'webgpu',
  })

  const device = 'webgpu'
  globalThis.postMessage({ type: 'info', data: { message: `Using device: "${device}"` } } satisfies WorkerMessageEventInfo)
  globalThis.postMessage({ type: 'info', data: { message: 'Loading models...', duration: 'until_next' } } satisfies WorkerMessageEventInfo)

  silero_vad = await AutoModel.from_pretrained(
    'onnx-community/silero-vad',
    {
      config: { model_type: 'custom' } as PretrainedConfig,
      dtype: 'fp32',
      progress_callback: progress => globalThis.postMessage({ type: 'progress', data: { message: progress } } satisfies WorkerMessageEventProgress),
    },
  ).catch((error: Error) => {
    globalThis.postMessage({ type: 'error', data: { error, message: error.message } } satisfies WorkerMessageEventError<Error>)
    throw error
  })

  transcriber = await pipeline(
    'automatic-speech-recognition',
    'onnx-community/whisper-base',
    {
      device,
      dtype: whisperDtypeMap[device as keyof typeof whisperDtypeMap],
      progress_callback: progress => globalThis.postMessage({ type: 'progress', data: { message: progress } } satisfies WorkerMessageEventProgress),
    },
  ).catch((error: Error) => {
    globalThis.postMessage({ type: 'error', data: { error, message: error.message } } satisfies WorkerMessageEventError<Error>)
    throw error
  })

  // Warm up the transcriber with a buffer of silence so the first real request is faster.
  await transcriber(new Float32Array(INPUT_SAMPLE_RATE))

  llm = await AutoModelForCausalLM.from_pretrained(
    'HuggingFaceTB/SmolLM2-1.7B-Instruct',
    {
      dtype: await isWebGPUSupported() ? 'q4f16' : 'int8',
      device: await isWebGPUSupported() ? 'webgpu' : 'wasm',
      progress_callback: progress => globalThis.postMessage({ type: 'progress', data: { message: progress } } satisfies WorkerMessageEventProgress),
    },
  ).catch((error: Error) => {
    globalThis.postMessage({ type: 'error', data: { error, message: error.message } } satisfies WorkerMessageEventError<Error>)
    throw error
  })

  tokenizer = await AutoTokenizer.from_pretrained(
    'HuggingFaceTB/SmolLM2-1.7B-Instruct',
  ).catch((error: Error) => {
    globalThis.postMessage({ type: 'error', data: { error, message: error.message } } satisfies WorkerMessageEventError<Error>)
    throw error
  })

  // Warm up the LLM by generating a single token.
  await llm.generate({ ...tokenizer('x'), max_new_tokens: 1 })

  globalThis.postMessage({
    type: 'status',
    data: {
      status: 'ready',
      message: 'Ready!',
      voices: tts.voices,
    },
  } as WorkerMessageEventStatus)
}

loadModels()

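/**
 * Runs the Silero VAD model on an incoming audio chunk and decides whether it
 * contains speech. A separate exit threshold applies while already recording,
 * so a brief dip in probability does not cut an utterance short.
 */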
async function vad(buffer?: Float32Array): Promise<boolean> {
  if (!buffer) {
    return false
  }

  const input = new Tensor('float32', buffer, [1, buffer.length])

  const { stateN, output } = await silero_vad({ input, sr, state })
  // Carry the recurrent VAD state over to the next chunk.
  state = stateN

  const isSpeech = output.data[0]

  return (
    // Case 1: the chunk is above the speech threshold.
    isSpeech > SPEECH_THRESHOLD
    // Case 2: we are already recording and still above the exit threshold.
    || (isRecording && isSpeech >= EXIT_THRESHOLD)
  )
}

interface SpeechData {
  start: number
  end: number
  duration: number
}

type BatchEncodingItem = number[] | number[][] | Tensor

/** Shape of the encoding returned by `tokenizer.apply_chat_template(..., { return_dict: true })`. */
interface BatchEncoding {
  input_ids: BatchEncodingItem
  attention_mask: BatchEncodingItem
  token_type_ids?: BatchEncodingItem
}

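/**
 * Full speech-to-speech turn: transcribe the recorded utterance with Whisper,
 * stream an LLM reply token by token into the Kokoro text splitter, and post
 * each synthesized audio chunk back to the main thread as it becomes available.
 */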
async function speechToSpeech(buffer: Float32Array, _data: SpeechData): Promise<void> {
  isPlaying = true

  // 1. Transcribe the recorded audio.
  const result = await transcriber(buffer)
  const text = (result as { text: string }).text.trim()

  if (['', '[BLANK_AUDIO]'].includes(text)) {
    // Whisper produced nothing useful, so skip this turn.
    return
  }

  messages.push({ role: 'user', content: text })

  // 2. Set up the text-to-speech stream that will voice the reply.
  const splitter = new TextSplitterStream()
  const stream = tts!.stream(splitter, { voice });
  (async () => {
    for await (const { text, audio } of stream) {
      globalThis.postMessage({ type: 'output', data: { text, result: audio } } satisfies WorkerMessageEventOutput)
    }
  })()

  // 3. Generate the assistant reply, streaming decoded text into the splitter.
  const inputs = tokenizer.apply_chat_template(messages, {
    add_generation_prompt: true,
    return_dict: true,
  }) as BatchEncoding

  const streamer = new TextStreamer(tokenizer, {
    skip_prompt: true,
    skip_special_tokens: true,
    callback_function: (text: string) => {
      splitter.push(text)
    },
    token_callback_function: () => {},
  })

  stopping_criteria = new InterruptableStoppingCriteria()
  type GenerationFunctionParameters = Parameters<typeof llm.generate>[0] & Record<string, any>

  const generatedRes = await llm.generate({
    ...inputs,
    past_key_values: past_key_values_cache,
    do_sample: false,
    max_new_tokens: 1024,
    streamer,
    stopping_criteria: stopping_criteria as unknown as StoppingCriteriaList,
    return_dict_in_generate: true,
  } as GenerationFunctionParameters)

  const { past_key_values, sequences } = generatedRes as CausalLMOutputWithPast & { sequences: Tensor }
  past_key_values_cache = past_key_values

  // No more text will be pushed; let the TTS stream finish.
  splitter.close()

  // Decode only the newly generated tokens (everything after the prompt).
  const decoded = tokenizer.batch_decode(
    sequences.slice(null, [(inputs.input_ids as Tensor).dims[1], null as any]),
    { skip_special_tokens: true },
  )

  messages.push({ role: 'assistant', content: decoded[0] })
}

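// Number of samples seen since the last chunk classified as speech; used to
// decide when the speaker has gone silent.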
let postSpeechSamples = 0

function resetAfterRecording(offset = 0): void {
  globalThis.postMessage({
    type: 'status',
    data: {
      status: 'recording_end',
      message: 'Transcribing...',
      duration: 'until_next',
    },
  } satisfies WorkerMessageEventStatus)

  BUFFER.fill(0, offset)
  bufferPointer = offset
  isRecording = false
  postSpeechSamples = 0
}

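/**
 * Prepends the retained pre-speech chunks to the recorded audio, hands the
 * combined buffer to `speechToSpeech`, and resets the audio buffer, keeping
 * any overflow samples that arrived after it filled up.
 */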
function dispatchForTranscriptionAndResetAudioBuffer(overflow?: Float32Array): void {
  // Compute rough start/end timestamps of the utterance from the buffer length.
  const now = Date.now()
  const end
    = now - ((postSpeechSamples + SPEECH_PAD_SAMPLES) / INPUT_SAMPLE_RATE) * 1000
  const start = end - (bufferPointer / INPUT_SAMPLE_RATE) * 1000
  const duration = end - start
  const overflowLength = overflow?.length ?? 0

  const buffer = BUFFER.slice(0, bufferPointer + SPEECH_PAD_SAMPLES)

  // Prepend the chunks captured just before speech started.
  const prevLength = prevBuffers.reduce((acc, b) => acc + b.length, 0)
  const paddedBuffer = new Float32Array(prevLength + buffer.length)
  let offset = 0
  for (const prev of prevBuffers) {
    paddedBuffer.set(prev, offset)
    offset += prev.length
  }
  paddedBuffer.set(buffer, offset)
  speechToSpeech(paddedBuffer, { start, end, duration })

  // If the buffer overflowed, the spill-over becomes the start of the next recording.
  if (overflow) {
    BUFFER.set(overflow, 0)
  }
  resetAfterRecording(overflowLength)
}

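// Main message loop: handles control messages (start_call, end_call, interrupt,
// set_voice, playback_ended) and streams incoming `audio` chunks through VAD,
// buffering them until enough trailing silence marks the end of an utterance.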
globalThis.onmessage = async (event: MessageEvent) => {
  const { type, buffer } = event.data

  // Ignore incoming audio while the assistant is speaking.
  if (type === 'audio' && isPlaying)
    return

  switch (type) {
    case 'start_call': {
      const name = tts!.voices[voice ?? 'af_heart']?.name ?? 'Heart'
      greet(`Hey there, my name is ${name}! How can I help you today?`)
      return
    }
    case 'end_call':
      messages = [SYSTEM_MESSAGE]
      past_key_values_cache = null
      break
    case 'interrupt':
      stopping_criteria?.interrupt()
      return
    case 'set_voice':
      voice = event.data.voice

      globalThis.postMessage({
        type: 'set_voice_response',
        data: {
          ok: true,
        },
      } satisfies WorkerMessageEventSetVoiceResponse)

      return
    case 'playback_ended':
      isPlaying = false
      return
  }

  const wasRecording = isRecording
  const isSpeech = await vad(buffer)

  if (!wasRecording && !isSpeech) {
    // Not recording and no speech detected: keep a short history of recent
    // chunks so the start of the next utterance is not clipped.
    if (prevBuffers.length >= MAX_NUM_PREV_BUFFERS) {
      prevBuffers.shift()
    }
    prevBuffers.push(buffer)
    return
  }

  const remaining = BUFFER.length - bufferPointer
  if (buffer.length >= remaining) {
    // The buffer is full: top it up, dispatch for transcription and carry the
    // overflow into the next recording.
    BUFFER.set(buffer.subarray(0, remaining), bufferPointer)
    bufferPointer += remaining

    const overflow = buffer.subarray(remaining)
    dispatchForTranscriptionAndResetAudioBuffer(overflow)
    return
  }
  else {
    BUFFER.set(buffer, bufferPointer)
    bufferPointer += buffer.length
  }

  if (isSpeech) {
    if (!isRecording) {
      globalThis.postMessage({
        type: 'status',
        data: {
          status: 'recording_start',
          message: 'Listening...',
          duration: 'until_next',
        },
      } satisfies WorkerMessageEventStatus)
    }

    // Speech detected: start (or keep) recording and reset the silence counter.
    isRecording = true
    postSpeechSamples = 0

    return
  }

  postSpeechSamples += buffer.length

  // Still within the allowed trailing silence: wait for more audio.
  if (postSpeechSamples < MIN_SILENCE_DURATION_SAMPLES) {
    return
  }

  // The detected speech was too short to be meaningful: discard it.
  if (bufferPointer < MIN_SPEECH_DURATION_SAMPLES) {
    resetAfterRecording()
    return
  }

  dispatchForTranscriptionAndResetAudioBuffer()
}

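/**
 * Speaks a canned greeting through the TTS stream and records it in the chat
 * history as an assistant message.
 */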
function greet(text: string): void {
  isPlaying = true

  const splitter = new TextSplitterStream()
  const stream = tts!.stream(splitter, { voice });

  (async () => {
    for await (const { text: chunkText, audio } of stream) {
      globalThis.postMessage({ type: 'output', data: { text: chunkText, result: audio } } satisfies WorkerMessageEventOutput)
    }
  })()

  splitter.push(text)
  splitter.close()
  messages.push({ role: 'assistant', content: text })
}