import { ModelCache } from './model-cache';
import { listChatModelsIterator } from './list-chat-models.js';

import curatedList from './curated-model-list.json' assert { type: 'json' };
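
// Message protocol handled by this worker (see handleMessage below):
//   -> { id, type: 'listChatModels', params? }                   list available chat models
//   -> { id, type: 'cancelListChatModels' }                      abort the listing registered under the same id
//   -> { id, type: 'loadModel', modelName? }                     load a model into the cache
//   -> { id, type: 'runPrompt', prompt, modelName?, options? }   run a single generation
//   <- { id?, type: 'status' | 'progress' | 'response' | 'error', ... }
//
// Minimal main-thread wiring (a sketch; assumes this file is loaded as a module
// worker whose entry script calls bootWorker(), and that 'worker.js' is that entry):
//
//   const worker = new Worker(new URL('./worker.js', import.meta.url), { type: 'module' });
//   worker.addEventListener('message', ({ data }) => {
//     if (data.type === 'response') console.log(data.id, data.result);
//   });
//   worker.postMessage({ id: 1, type: 'loadModel' });
//   worker.postMessage({ id: 2, type: 'runPrompt', prompt: 'Hello there' });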

export function bootWorker() {
  const modelCache = new ModelCache();
  // Default to the first known model until the client explicitly loads one.
  let selectedModel = modelCache.knownModels[0];

  try {
    self.postMessage({ type: 'status', status: 'initializing' });
  } catch (e) {
    // Never let a failed status message take down the worker.
  }

  self.postMessage({ type: 'status', status: 'backend-detected', backend: modelCache.backend });

  self.postMessage({ type: 'ready', env: modelCache.env, backend: modelCache.backend });

  self.addEventListener('message', handleMessage);

  // In-flight cancellable tasks (currently model listings), keyed by request id.
  const activeTasks = new Map();

  // Route each incoming request to its handler; the request id is echoed back on
  // every status/response/error message so the caller can correlate replies.
  async function handleMessage({ data }) {
    const { id } = data;
    try {
      if (data.type === 'listChatModels') {
        handleListChatModels(data).catch(err => {
          self.postMessage({ id, type: 'error', error: String(err) });
        });
      } else if (data.type === 'cancelListChatModels') {
        const task = activeTasks.get(id);
        if (task && task.abort) task.abort();
        self.postMessage({ id, type: 'response', result: { cancelled: true } });
      } else if (data.type === 'loadModel') {
        const { modelName = modelCache.knownModels[0] } = data;
        try {
          await modelCache.getModel({ modelName });
          selectedModel = modelName;
          self.postMessage({ id, type: 'response', result: { model: modelName, status: 'loaded' } });
        } catch (err) {
          self.postMessage({ id, type: 'error', error: String(err) });
        }
      } else if (data.type === 'runPrompt') {
        handleRunPrompt(data);
      } else {
        if (id) self.postMessage({ id, type: 'error', error: 'unknown-message-type' });
      }
    } catch (err) {
      if (id) self.postMessage({ id, type: 'error', error: String(err) });
    }
  }

  // Run one prompt against the requested (or currently selected) model. Two engine
  // shapes are supported: a WebLLM engine exposing an OpenAI-style chat.completions
  // API, and a transformers.js text-generation pipeline, which is a callable function.
  async function handleRunPrompt({ prompt, modelName = selectedModel, id, options }) {
    try {
      const engine = await modelCache.getModel({ modelName });
      if (!engine) throw new Error('engine not available');

      self.postMessage({ id, type: 'status', status: 'inference-start', model: modelName });

      let text;
      if (engine.chat?.completions?.create) {
        // WebLLM path.
        try {
          const response = await engine.chat.completions.create({
            messages: [{ role: 'user', content: prompt }],
            max_tokens: options?.max_new_tokens ?? 250,
            temperature: options?.temperature ?? 0.7
          });
          text = response.choices[0]?.message?.content ?? '';
        } catch (err) {
          console.error(`WebLLM inference failed for ${modelName}: ${err.message}`);
          throw err;
        }
      } else if (typeof engine === 'function') {
        // transformers.js pipeline path; caller-supplied options override the defaults.
        const out = await engine(prompt, {
          max_new_tokens: 250,
          temperature: 0.7,
          do_sample: true,
          pad_token_id: engine.tokenizer?.eos_token_id,
          return_full_text: false,
          ...options
        });
        text = extractText(out);
      } else {
        throw new Error('Unknown engine type');
      }

      self.postMessage({ id, type: 'status', status: 'inference-done', model: modelName });
      self.postMessage({ id, type: 'response', result: text });
    } catch (err) {
      self.postMessage({ id, type: 'error', error: String(err) });
    }
  }
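
  // Example runPrompt exchange, as emitted by the handler above (illustrative values):
  //   -> { id: 7, type: 'runPrompt', prompt: 'Write a haiku', options: { max_new_tokens: 64 } }
  //   <- { id: 7, type: 'status', status: 'inference-start', model: '<selected model>' }
  //   <- { id: 7, type: 'status', status: 'inference-done',  model: '<selected model>' }
  //   <- { id: 7, type: 'response', result: '...generated text...' }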

  // List available chat models. This currently short-circuits with the bundled
  // curated list; the streaming enumeration below (via listChatModelsIterator) is
  // retained but unreachable until the early return is removed.
  async function handleListChatModels({ id, params = {} }) {
    self.postMessage({ id, type: 'response', result: { models: curatedList } });
    return;

    // --- Live enumeration path (currently disabled by the early return above) ---
    const iterator = listChatModelsIterator(params);
    let sawDone = false;

    // Progress deltas are buffered and flushed either every BATCH_MS milliseconds or
    // once BATCH_MAX items accumulate, so the main thread is not flooded with messages.
    let batchBuffer = [];
    let batchTimer = null;
    const BATCH_MS = 50;
    const BATCH_MAX = 50;

    function flushBatch() {
      if (!batchBuffer || batchBuffer.length === 0) return;
      try {
        console.log('Loading: ', batchBuffer[batchBuffer.length - 1]);
        self.postMessage({ id, type: 'progress', batch: true, items: batchBuffer.splice(0) });
      } catch (e) {
        // A dropped progress batch is not fatal; keep streaming.
      }
      if (batchTimer) { clearTimeout(batchTimer); batchTimer = null; }
    }

    function enqueueProgress(delta) {
      batchBuffer.push(delta);
      if (batchBuffer.length >= BATCH_MAX) return flushBatch();
      if (!batchTimer) {
        batchTimer = setTimeout(flushBatch, BATCH_MS);
      }
    }

    // Register an abort hook so cancelListChatModels can stop the iterator early.
    activeTasks.set(id, { abort: () => iterator.return() });
    let lastBatchDelta;
    try {
      for await (const delta of iterator) {
        try { enqueueProgress(delta); } catch (e) { }
        if (delta?.models) lastBatchDelta = delta;
        if (delta && delta.status === 'done') {
          sawDone = true;
        }
      }

      flushBatch();
      if (!sawDone) {
        // The iterator ended without reporting 'done', most likely because it was aborted.
        self.postMessage({ id, type: 'response', result: { cancelled: true } });
      } else {
        self.postMessage({ id, type: 'response', result: lastBatchDelta });
      }
    } catch (err) {
      flushBatch();
      self.postMessage({ id, type: 'error', error: String(err), code: err.code || null });
    } finally {
      activeTasks.delete(id);
    }
  }
}

// Normalize transformers.js pipeline output into a plain string.
function extractText(output) {
  try {
    if (!output) return '';
    if (typeof output === 'string') return output;
    if (Array.isArray(output) && output.length > 0) {
      // Join per-element text so callers always receive a single string.
      return output
        .map(el => {
          if (typeof el === 'string') return el;
          if (el?.generated_text) return el.generated_text;
          if (el?.text) return el.text;
          return '';
        })
        .join('');
    }
    return String(output);
  } catch (e) {
    return '';
  }
}
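
// Illustrative inputs (shapes commonly produced by transformers.js text-generation
// pipelines; the exact fields depend on the pipeline and its options):
//   extractText([{ generated_text: 'Hello world' }])  -> 'Hello world'
//   extractText([{ text: 'partial chunk' }])          -> 'partial chunk'
//   extractText('already a string')                   -> 'already a string'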