Manjunath Kudlur committed on
Commit
ce2e9d0
·
1 Parent(s): 115eadd

Choose backend

Browse files
Files changed (4) hide show
  1. decoder_worker.js +7 -4
  2. encoder_worker.js +6 -3
  3. index.html +26 -0
  4. streaming_asr.js +23 -2
decoder_worker.js CHANGED
@@ -376,8 +376,11 @@ async function processMessage(e) {
376
  cfg = data.cfg;
377
  const onnxUrl = data.onnxUrl;
378
  const modelName = data.modelName;
 
379
  const dtype = 'fp32';
380
 
 
 
381
  tailLatency = cfg.n_future * cfg.encoder_depth;
382
 
383
  // Load tokenizer
@@ -394,7 +397,7 @@ async function processMessage(e) {
394
  self.postMessage({ type: 'status', message: 'Loading adapter...' });
395
  self.postMessage({ type: 'model_start', model: 'Adapter' });
396
  const adapterBuffer = await fetchModelWithProgress(adapterUrl, 'Adapter');
397
- adapterSession = await ort.InferenceSession.create(adapterBuffer);
398
  self.postMessage({ type: 'model_done', model: 'Adapter' });
399
 
400
  // Initialize decoder init
@@ -402,7 +405,7 @@ async function processMessage(e) {
402
  self.postMessage({ type: 'status', message: 'Loading decoder (init)...' });
403
  self.postMessage({ type: 'model_start', model: 'Decoder Init' });
404
  const decInitBuffer = await fetchModelWithProgress(decInitUrl, 'Decoder Init');
405
- decoderInitSession = await ort.InferenceSession.create(decInitBuffer);
406
  self.postMessage({ type: 'model_done', model: 'Decoder Init' });
407
 
408
  // Initialize decoder step
@@ -410,10 +413,10 @@ async function processMessage(e) {
410
  self.postMessage({ type: 'status', message: 'Loading decoder (step)...' });
411
  self.postMessage({ type: 'model_start', model: 'Decoder Step' });
412
  const decStepBuffer = await fetchModelWithProgress(decStepUrl, 'Decoder Step');
413
- decoderStepSession = await ort.InferenceSession.create(decStepBuffer);
414
  self.postMessage({ type: 'model_done', model: 'Decoder Step' });
415
 
416
- self.postMessage({ type: 'ready' });
417
  } catch (err) {
418
  self.postMessage({ type: 'error', message: err.message });
419
  }
 
376
  cfg = data.cfg;
377
  const onnxUrl = data.onnxUrl;
378
  const modelName = data.modelName;
379
+ const backend = data.backend || 'wasm';
380
  const dtype = 'fp32';
381
 
382
+ const sessionOptions = { executionProviders: [backend] };
383
+
384
  tailLatency = cfg.n_future * cfg.encoder_depth;
385
 
386
  // Load tokenizer
 
397
  self.postMessage({ type: 'status', message: 'Loading adapter...' });
398
  self.postMessage({ type: 'model_start', model: 'Adapter' });
399
  const adapterBuffer = await fetchModelWithProgress(adapterUrl, 'Adapter');
400
+ adapterSession = await ort.InferenceSession.create(adapterBuffer, sessionOptions);
401
  self.postMessage({ type: 'model_done', model: 'Adapter' });
402
 
403
  // Initialize decoder init
 
405
  self.postMessage({ type: 'status', message: 'Loading decoder (init)...' });
406
  self.postMessage({ type: 'model_start', model: 'Decoder Init' });
407
  const decInitBuffer = await fetchModelWithProgress(decInitUrl, 'Decoder Init');
408
+ decoderInitSession = await ort.InferenceSession.create(decInitBuffer, sessionOptions);
409
  self.postMessage({ type: 'model_done', model: 'Decoder Init' });
410
 
411
  // Initialize decoder step
 
413
  self.postMessage({ type: 'status', message: 'Loading decoder (step)...' });
414
  self.postMessage({ type: 'model_start', model: 'Decoder Step' });
415
  const decStepBuffer = await fetchModelWithProgress(decStepUrl, 'Decoder Step');
416
+ decoderStepSession = await ort.InferenceSession.create(decStepBuffer, sessionOptions);
417
  self.postMessage({ type: 'model_done', model: 'Decoder Step' });
418
 
419
+ self.postMessage({ type: 'ready', backend: backend });
420
  } catch (err) {
421
  self.postMessage({ type: 'error', message: err.message });
422
  }
encoder_worker.js CHANGED
@@ -241,8 +241,11 @@ async function processMessage(e) {
241
  cfg = data.cfg;
242
  const onnxUrl = data.onnxUrl;
243
  const modelName = data.modelName;
 
244
  const dtype = 'fp32';
245
 
 
 
246
  tailLatency = cfg.n_future * cfg.encoder_depth;
247
 
248
  // Initialize preprocessor
@@ -250,7 +253,7 @@ async function processMessage(e) {
250
  self.postMessage({ type: 'status', message: 'Loading preprocessor...' });
251
  self.postMessage({ type: 'model_start', model: 'Preprocessor' });
252
  const prepBuffer = await fetchModelWithProgress(prepUrl, 'Preprocessor');
253
- prepSession = await ort.InferenceSession.create(prepBuffer);
254
  self.postMessage({ type: 'model_done', model: 'Preprocessor' });
255
 
256
  prepDim = cfg.dim;
@@ -263,7 +266,7 @@ async function processMessage(e) {
263
  self.postMessage({ type: 'status', message: 'Loading encoder...' });
264
  self.postMessage({ type: 'model_start', model: 'Encoder' });
265
  const encBuffer = await fetchModelWithProgress(encUrl, 'Encoder');
266
- encSession = await ort.InferenceSession.create(encBuffer);
267
  self.postMessage({ type: 'model_done', model: 'Encoder' });
268
 
269
  encDim = cfg.dim;
@@ -272,7 +275,7 @@ async function processMessage(e) {
272
  encEncoderDepth = cfg.encoder_depth;
273
  encContextSize = cfg.encoder_depth * (cfg.n_past + cfg.n_future);
274
 
275
- self.postMessage({ type: 'ready' });
276
  } catch (err) {
277
  self.postMessage({ type: 'error', message: err.message });
278
  }
 
241
  cfg = data.cfg;
242
  const onnxUrl = data.onnxUrl;
243
  const modelName = data.modelName;
244
+ const backend = data.backend || 'wasm';
245
  const dtype = 'fp32';
246
 
247
+ const sessionOptions = { executionProviders: [backend] };
248
+
249
  tailLatency = cfg.n_future * cfg.encoder_depth;
250
 
251
  // Initialize preprocessor
 
253
  self.postMessage({ type: 'status', message: 'Loading preprocessor...' });
254
  self.postMessage({ type: 'model_start', model: 'Preprocessor' });
255
  const prepBuffer = await fetchModelWithProgress(prepUrl, 'Preprocessor');
256
+ prepSession = await ort.InferenceSession.create(prepBuffer, sessionOptions);
257
  self.postMessage({ type: 'model_done', model: 'Preprocessor' });
258
 
259
  prepDim = cfg.dim;
 
266
  self.postMessage({ type: 'status', message: 'Loading encoder...' });
267
  self.postMessage({ type: 'model_start', model: 'Encoder' });
268
  const encBuffer = await fetchModelWithProgress(encUrl, 'Encoder');
269
+ encSession = await ort.InferenceSession.create(encBuffer, sessionOptions);
270
  self.postMessage({ type: 'model_done', model: 'Encoder' });
271
 
272
  encDim = cfg.dim;
 
275
  encEncoderDepth = cfg.encoder_depth;
276
  encContextSize = cfg.encoder_depth * (cfg.n_past + cfg.n_future);
277
 
278
+ self.postMessage({ type: 'ready', backend: backend });
279
  } catch (err) {
280
  self.postMessage({ type: 'error', message: err.message });
281
  }
index.html CHANGED
@@ -57,6 +57,23 @@
57
  .status-dot.listening { background: #00ff88; animation: pulse 1s infinite; }
58
  .status-dot.recording { background: #ff4444; animation: pulse 0.5s infinite; }
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  @keyframes pulse {
61
  0%, 100% { opacity: 1; }
62
  50% { opacity: 0.5; }
@@ -622,6 +639,14 @@
622
  <option value="spindlier">Moonshine Spindlier</option>
623
  </select>
624
  </div>
 
 
 
 
 
 
 
 
625
  <div class="config-item">
626
  <label>ONNX Files URL</label>
627
  <input type="text" id="onnxUrl" placeholder="e.g., ./models or https://..." value="./models">
@@ -641,6 +666,7 @@
641
  <div class="status-indicator">
642
  <div class="status-dot" id="statusDot"></div>
643
  <span id="statusText">Ready</span>
 
644
  </div>
645
  <div class="controls">
646
  <button class="btn-primary" id="startBtn">Start Listening</button>
 
57
  .status-dot.listening { background: #00ff88; animation: pulse 1s infinite; }
58
  .status-dot.recording { background: #ff4444; animation: pulse 0.5s infinite; }
59
 
60
+ .backend-badge {
61
+ display: none;
62
+ padding: 3px 8px;
63
+ border-radius: 10px;
64
+ font-size: 11px;
65
+ font-weight: 600;
66
+ text-transform: uppercase;
67
+ margin-left: 10px;
68
+ background: #444;
69
+ color: #ccc;
70
+ }
71
+
72
+ .backend-badge.visible { display: inline-block; }
73
+ .backend-badge.wasm { background: #555; color: #aaa; }
74
+ .backend-badge.webgl { background: #f90; color: #000; }
75
+ .backend-badge.webgpu { background: linear-gradient(90deg, #00d4ff, #00ff88); color: #000; }
76
+
77
  @keyframes pulse {
78
  0%, 100% { opacity: 1; }
79
  50% { opacity: 0.5; }
 
639
  <option value="spindlier">Moonshine Spindlier</option>
640
  </select>
641
  </div>
642
+ <div class="config-item">
643
+ <label>Backend</label>
644
+ <select id="backendSelect">
645
+ <option value="wasm">WASM (CPU)</option>
646
+ <option value="webgl">WebGL (GPU)</option>
647
+ <option value="webgpu">WebGPU (GPU)</option>
648
+ </select>
649
+ </div>
650
  <div class="config-item">
651
  <label>ONNX Files URL</label>
652
  <input type="text" id="onnxUrl" placeholder="e.g., ./models or https://..." value="./models">
 
666
  <div class="status-indicator">
667
  <div class="status-dot" id="statusDot"></div>
668
  <span id="statusText">Ready</span>
669
+ <span class="backend-badge" id="backendBadge"></span>
670
  </div>
671
  <div class="controls">
672
  <button class="btn-primary" id="startBtn">Start Listening</button>
streaming_asr.js CHANGED
@@ -187,6 +187,7 @@ class PipelinedStreamingASR {
187
  constructor(config) {
188
  this.modelName = config.modelName || 'sleeker';
189
  this.onnxUrl = config.onnxUrl || './models';
 
190
  this.onsetThreshold = config.onsetThreshold || 0.5;
191
  this.offsetThreshold = config.offsetThreshold || 0.3;
192
  this.emaAlpha = config.emaAlpha || 0.3;
@@ -236,6 +237,10 @@ class PipelinedStreamingASR {
236
  this.onLiveCaption = null;
237
  this.onStatusUpdate = null;
238
  this.onQueueUpdate = null;
 
 
 
 
239
  }
240
 
241
  async loadModels(progressCallback, detailedProgressCallback) {
@@ -348,7 +353,8 @@ class PipelinedStreamingASR {
348
  data: {
349
  cfg: this.cfg,
350
  onnxUrl: this.onnxUrl,
351
- modelName: this.modelName
 
352
  }
353
  });
354
  });
@@ -364,6 +370,8 @@ class PipelinedStreamingASR {
364
  switch (type) {
365
  case 'ready':
366
  this.decoderReady = true;
 
 
367
  resolve();
368
  break;
369
  case 'error':
@@ -391,7 +399,8 @@ class PipelinedStreamingASR {
391
  data: {
392
  cfg: this.cfg,
393
  onnxUrl: this.onnxUrl,
394
- modelName: this.modelName
 
395
  }
396
  });
397
  });
@@ -733,7 +742,9 @@ class ASRDemoUI {
733
  this.liveCaptionText = document.getElementById('liveCaptionText');
734
  this.liveCaptionMobile = document.getElementById('liveCaptionMobile');
735
  this.liveCaptionTextMobile = document.getElementById('liveCaptionTextMobile');
 
736
  this.modelSelect = document.getElementById('modelSelect');
 
737
  this.onnxUrl = document.getElementById('onnxUrl');
738
  this.onsetThreshold = document.getElementById('onsetThreshold');
739
  this.offsetThreshold = document.getElementById('offsetThreshold');
@@ -782,6 +793,7 @@ class ASRDemoUI {
782
  const config = {
783
  modelName: this.modelSelect.value,
784
  onnxUrl: this.onnxUrl.value || './models',
 
785
  onsetThreshold: parseFloat(this.onsetThreshold.value),
786
  offsetThreshold: parseFloat(this.offsetThreshold.value)
787
  };
@@ -792,6 +804,7 @@ class ASRDemoUI {
792
  this.asr.onTranscript = (text, segmentId) => this.addTranscript(text, segmentId);
793
  this.asr.onLiveCaption = (text) => this.updateLiveCaption(text);
794
  this.asr.onStatusUpdate = (status, text) => this.updateStatus(status, text);
 
795
 
796
  await this.asr.loadModels(
797
  (text) => {
@@ -826,6 +839,7 @@ class ASRDemoUI {
826
  this.stopBtn.disabled = true;
827
  this.disableConfig(false);
828
  this.updateStatus('idle', 'Ready');
 
829
  }
830
 
831
  updateVadDisplay(prob, history, segmentEvents = [], historyStartTime = 0) {
@@ -978,6 +992,12 @@ class ASRDemoUI {
978
  this.statusText.textContent = text;
979
  }
980
 
 
 
 
 
 
 
981
  showLoading(text) {
982
  this.loadingText.textContent = text;
983
  this.loadingProgressFill.style.width = '0%';
@@ -1021,6 +1041,7 @@ class ASRDemoUI {
1021
 
1022
  disableConfig(disabled) {
1023
  this.modelSelect.disabled = disabled;
 
1024
  this.onnxUrl.disabled = disabled;
1025
  this.onsetThreshold.disabled = disabled;
1026
  this.offsetThreshold.disabled = disabled;
 
187
  constructor(config) {
188
  this.modelName = config.modelName || 'sleeker';
189
  this.onnxUrl = config.onnxUrl || './models';
190
+ this.backendChoice = config.backend || 'wasm';
191
  this.onsetThreshold = config.onsetThreshold || 0.5;
192
  this.offsetThreshold = config.offsetThreshold || 0.3;
193
  this.emaAlpha = config.emaAlpha || 0.3;
 
237
  this.onLiveCaption = null;
238
  this.onStatusUpdate = null;
239
  this.onQueueUpdate = null;
240
+ this.onBackendUpdate = null;
241
+
242
+ // Backend info
243
+ this.backend = 'unknown';
244
  }
245
 
246
  async loadModels(progressCallback, detailedProgressCallback) {
 
353
  data: {
354
  cfg: this.cfg,
355
  onnxUrl: this.onnxUrl,
356
+ modelName: this.modelName,
357
+ backend: this.backendChoice
358
  }
359
  });
360
  });
 
370
  switch (type) {
371
  case 'ready':
372
  this.decoderReady = true;
373
+ this.backend = e.data.backend || 'wasm';
374
+ this.onBackendUpdate?.(this.backend);
375
  resolve();
376
  break;
377
  case 'error':
 
399
  data: {
400
  cfg: this.cfg,
401
  onnxUrl: this.onnxUrl,
402
+ modelName: this.modelName,
403
+ backend: this.backendChoice
404
  }
405
  });
406
  });
 
742
  this.liveCaptionText = document.getElementById('liveCaptionText');
743
  this.liveCaptionMobile = document.getElementById('liveCaptionMobile');
744
  this.liveCaptionTextMobile = document.getElementById('liveCaptionTextMobile');
745
+ this.backendBadge = document.getElementById('backendBadge');
746
  this.modelSelect = document.getElementById('modelSelect');
747
+ this.backendSelect = document.getElementById('backendSelect');
748
  this.onnxUrl = document.getElementById('onnxUrl');
749
  this.onsetThreshold = document.getElementById('onsetThreshold');
750
  this.offsetThreshold = document.getElementById('offsetThreshold');
 
793
  const config = {
794
  modelName: this.modelSelect.value,
795
  onnxUrl: this.onnxUrl.value || './models',
796
+ backend: this.backendSelect.value,
797
  onsetThreshold: parseFloat(this.onsetThreshold.value),
798
  offsetThreshold: parseFloat(this.offsetThreshold.value)
799
  };
 
804
  this.asr.onTranscript = (text, segmentId) => this.addTranscript(text, segmentId);
805
  this.asr.onLiveCaption = (text) => this.updateLiveCaption(text);
806
  this.asr.onStatusUpdate = (status, text) => this.updateStatus(status, text);
807
+ this.asr.onBackendUpdate = (backend) => this.updateBackendBadge(backend);
808
 
809
  await this.asr.loadModels(
810
  (text) => {
 
839
  this.stopBtn.disabled = true;
840
  this.disableConfig(false);
841
  this.updateStatus('idle', 'Ready');
842
+ this.backendBadge.classList.remove('visible');
843
  }
844
 
845
  updateVadDisplay(prob, history, segmentEvents = [], historyStartTime = 0) {
 
992
  this.statusText.textContent = text;
993
  }
994
 
995
+ updateBackendBadge(backend) {
996
+ const labels = { 'wasm': 'WASM', 'webgl': 'WebGL', 'webgpu': 'WebGPU' };
997
+ this.backendBadge.textContent = labels[backend] || backend;
998
+ this.backendBadge.className = 'backend-badge visible ' + backend;
999
+ }
1000
+
1001
  showLoading(text) {
1002
  this.loadingText.textContent = text;
1003
  this.loadingProgressFill.style.width = '0%';
 
1041
 
1042
  disableConfig(disabled) {
1043
  this.modelSelect.disabled = disabled;
1044
+ this.backendSelect.disabled = disabled;
1045
  this.onnxUrl.disabled = disabled;
1046
  this.onsetThreshold.disabled = disabled;
1047
  this.offsetThreshold.disabled = disabled;