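"""FastAPI demo server for realtime speech-to-text over a WebSocket.

Browser microphone audio (16 kHz, 16-bit PCM) is segmented with Silero VAD and
transcribed with Moonshine ONNX models ("tiny" or "base"); partial and final
transcripts are pushed back to the client as JSON messages.
"""
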
import time

import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse

from silero_vad import VADIterator, load_silero_vad
from moonshine_onnx import MoonshineOnnxModel, load_tokenizer

SAMPLING_RATE = 16000   # Moonshine and Silero VAD both expect 16 kHz audio.
CHUNK_SIZE = 512        # Samples per VAD chunk (Silero VAD requires 512 samples at 16 kHz).
LOOKBACK_CHUNKS = 5     # Pre-speech chunks kept so utterance onsets are not clipped.
MAX_SPEECH_SECS = 15    # Force a transcription flush after this much continuous speech.
MIN_REFRESH_SECS = 0.5  # Minimum interval between partial-transcript refreshes.

app = FastAPI()


class Transcriber:
    def __init__(self, model_name: str, rate: int = 16000):
        if rate != 16000:
            raise ValueError("Moonshine supports sampling rate 16000 Hz.")
        self.model = MoonshineOnnxModel(model_name=model_name)
        self.rate = rate
        self.tokenizer = load_tokenizer()

        # Running totals for simple performance accounting.
        self.inference_secs = 0
        self.number_inferences = 0
        self.speech_secs = 0

        # Warm up the model on one second of silence so the first real request
        # does not pay the initialization cost.
        self.__call__(np.zeros(int(rate), dtype=np.float32))

    def __call__(self, speech: np.ndarray) -> str:
        """Returns a transcription of the given speech (a float32 numpy array)."""
        self.number_inferences += 1
        self.speech_secs += len(speech) / self.rate
        start_time = time.time()
        tokens = self.model.generate(speech[np.newaxis, :].astype(np.float32))
        text = self.tokenizer.decode_batch(tokens)[0]
        self.inference_secs += time.time() - start_time
        return text


def pcm16_to_float32(pcm_data: bytes) -> np.ndarray:
    """Convert 16-bit PCM bytes into a float32 numpy array with values in [-1, 1]."""
    int_data = np.frombuffer(pcm_data, dtype=np.int16)
    float_data = int_data.astype(np.float32) / 32768.0
    return float_data


# Load both Moonshine models up front so the client can switch between them
# without a model-loading delay mid-session.
model_name_tiny = "moonshine/tiny"
model_name_base = "moonshine/base"
transcriber_tiny = Transcriber(model_name=model_name_tiny, rate=SAMPLING_RATE)
transcriber_base = Transcriber(model_name=model_name_base, rate=SAMPLING_RATE)

# Note: a single VADIterator is shared by every connection, so this demo
# assumes one active client at a time.
vad_model = load_silero_vad(onnx=True)
vad_iterator = VADIterator(
    model=vad_model,
    sampling_rate=SAMPLING_RATE,
    threshold=0.5,
    min_silence_duration_ms=300,
)


@app.websocket("/ws/transcribe")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()

    caption_cache = []
    lookback_size = LOOKBACK_CHUNKS * CHUNK_SIZE
    speech = np.empty(0, dtype=np.float32)
    recording = False
    last_partial_time = time.time()
    current_model = transcriber_tiny
    last_output = ""

    try:
        while True:
            data = await websocket.receive()
            if data["type"] == "websocket.disconnect":
                # websocket.receive() reports disconnects as messages rather than
                # raising, so surface the exception for the handler below.
                raise WebSocketDisconnect(data.get("code", 1000))
            if data["type"] == "websocket.receive":
                # Plain-text frames are model-switch commands; binary frames carry audio.
                if data.get("text") == "switch_to_tiny":
                    current_model = transcriber_tiny
                    continue
                elif data.get("text") == "switch_to_base":
                    current_model = transcriber_base
                    continue
                if data.get("bytes") is None:
                    continue

                chunk = pcm16_to_float32(data["bytes"])
                speech = np.concatenate((speech, chunk))
                if not recording:
                    # While idle, keep only a short lookback window so the start
                    # of the next utterance is preserved.
                    speech = speech[-lookback_size:]

                vad_result = vad_iterator(chunk)
                current_time = time.time()

                if vad_result:
                    if "start" in vad_result and not recording:
                        recording = True
                        await websocket.send_json({"type": "status", "message": "speaking_started"})

                    if "end" in vad_result and recording:
                        recording = False
                        text = current_model(speech)
                        await websocket.send_json({"type": "final", "transcript": text})
                        caption_cache.append(text)
                        speech = np.empty(0, dtype=np.float32)
                        # Soft-reset the VAD iterator so it is ready for the next utterance.
                        vad_iterator.triggered = False
                        vad_iterator.temp_end = 0
                        vad_iterator.current_sample = 0
                        await websocket.send_json({"type": "status", "message": "speaking_stopped"})
                elif recording:
                    if (len(speech) / SAMPLING_RATE) > MAX_SPEECH_SECS:
                        # Too much continuous speech; flush a final transcript so latency stays bounded.
                        recording = False
                        text = current_model(speech)
                        await websocket.send_json({"type": "final", "transcript": text})
                        caption_cache.append(text)
                        # Clear the buffer and soft-reset the VAD so a fresh "start"
                        # fires if the speaker keeps talking.
                        speech = np.empty(0, dtype=np.float32)
                        vad_iterator.triggered = False
                        vad_iterator.temp_end = 0
                        vad_iterator.current_sample = 0
                    elif (current_time - last_partial_time) > MIN_REFRESH_SECS + 0.2 * (len(speech) / SAMPLING_RATE):
                        # Throttled partial update: the refresh interval grows with the
                        # utterance, since longer audio takes longer to transcribe.
                        text = current_model(speech)
                        if last_output != text:
                            last_output = text
                            await websocket.send_json({"type": "partial", "transcript": text})
                        last_partial_time = current_time
    except WebSocketDisconnect:
        if recording and speech.size:
            # Transcribe whatever audio was buffered; the send may fail because
            # the client has already gone away.
            text = current_model(speech)
            try:
                await websocket.send_json({"type": "final", "transcript": text})
            except Exception:
                pass
        print("WebSocket disconnected")


@app.get("/", response_class=HTMLResponse)
async def get_home():
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="UTF-8">
        <title>Realtime Transcription</title>
        <link href="https://cdn.jsdelivr.net/npm/tailwindcss@2.2.19/dist/tailwind.min.css" rel="stylesheet">
    </head>
    <body class="bg-gray-100 p-6">
        <div class="max-w-3xl mx-auto bg-white p-6 rounded-lg shadow-md">
            <h1 class="text-2xl font-bold mb-4">Realtime Transcription</h1>
            <button onclick="startTranscription()" class="bg-blue-500 text-white px-4 py-2 rounded mb-4">Start Transcription</button>
            <select id="modelSelect" onchange="switchModel()" class="bg-gray-200 px-4 py-2 rounded mb-4">
                <option value="tiny">Tiny Model</option>
                <option value="base">Base Model</option>
            </select>
            <p id="status" class="text-gray-600 mb-4">Click start to begin transcription.</p>
            <p id="speakingStatus" class="text-gray-600 mb-4"></p>
            <div id="transcription" class="border p-4 rounded mb-4 h-64 overflow-auto"></div>
            <div id="visualizer" class="border p-4 rounded h-64">
                <canvas id="audioCanvas" class="w-full h-full"></canvas>
            </div>
        </div>
        <script>
            let ws;
            let audioContext;
            let scriptProcessor;
            let mediaStream;
            let currentLine = document.createElement('span');
            let analyser;
            let canvas, canvasContext;

            document.getElementById('transcription').appendChild(currentLine);
            canvas = document.getElementById('audioCanvas');
            canvasContext = canvas.getContext('2d');

            async function startTranscription() {
                document.getElementById("status").innerText = "Connecting...";
                // Use ws:// on plain HTTP (e.g. local development) and wss:// when served over HTTPS.
                const wsProtocol = location.protocol === "https:" ? "wss://" : "ws://";
                ws = new WebSocket(wsProtocol + location.host + "/ws/transcribe");
                ws.binaryType = 'arraybuffer';

                ws.onopen = async function() {
                    document.getElementById("status").innerText = "Connected";
                    try {
                        mediaStream = await navigator.mediaDevices.getUserMedia({ audio: true });
                        audioContext = new AudioContext({ sampleRate: 16000 });
                        const source = audioContext.createMediaStreamSource(mediaStream);
                        analyser = audioContext.createAnalyser();
                        analyser.fftSize = 2048;
                        const bufferLength = analyser.frequencyBinCount;
                        const dataArray = new Uint8Array(bufferLength);
                        source.connect(analyser);
                        // ScriptProcessorNode is deprecated but keeps this demo simple;
                        // 512-sample buffers match the server's VAD chunk size.
                        scriptProcessor = audioContext.createScriptProcessor(512, 1, 1);
                        scriptProcessor.onaudioprocess = function(event) {
                            // Stream the microphone chunk to the server as 16-bit PCM.
                            const inputData = event.inputBuffer.getChannelData(0);
                            const pcm16 = floatTo16BitPCM(inputData);
                            if (ws.readyState === WebSocket.OPEN) {
                                ws.send(pcm16);
                            }
                            // Draw the current waveform on the visualizer canvas.
                            analyser.getByteTimeDomainData(dataArray);
                            canvasContext.fillStyle = 'rgb(200, 200, 200)';
                            canvasContext.fillRect(0, 0, canvas.width, canvas.height);
                            canvasContext.lineWidth = 2;
                            canvasContext.strokeStyle = 'rgb(0, 0, 0)';
                            canvasContext.beginPath();
                            let sliceWidth = canvas.width * 1.0 / bufferLength;
                            let x = 0;
                            for (let i = 0; i < bufferLength; i++) {
                                let v = dataArray[i] / 128.0;
                                let y = v * canvas.height / 2;
                                if (i === 0) {
                                    canvasContext.moveTo(x, y);
                                } else {
                                    canvasContext.lineTo(x, y);
                                }
                                x += sliceWidth;
                            }
                            canvasContext.lineTo(canvas.width, canvas.height / 2);
                            canvasContext.stroke();
                        };
                        source.connect(scriptProcessor);
                        scriptProcessor.connect(audioContext.destination);
                    } catch (err) {
                        document.getElementById("status").innerText = "Error: " + err;
                    }
                };

                // Partial transcripts update the current line in gray; finals freeze it
                // in black and start a new line.
                ws.onmessage = function(event) {
                    const data = JSON.parse(event.data);
                    if (data.type === 'partial') {
                        currentLine.style.color = 'gray';
                        currentLine.textContent = data.transcript + ' ';
                    } else if (data.type === 'final') {
                        currentLine.style.color = 'black';
                        currentLine.textContent = data.transcript;
                        currentLine = document.createElement('span');
                        document.getElementById('transcription').appendChild(document.createElement('br'));
                        document.getElementById('transcription').appendChild(currentLine);
                    } else if (data.type === 'status') {
                        if (data.message === 'speaking_started') {
                            document.getElementById("speakingStatus").innerText = "Speaking Started";
                            document.getElementById("speakingStatus").style.color = "green";
                        } else if (data.message === 'speaking_stopped') {
                            document.getElementById("speakingStatus").innerText = "Speaking Stopped";
                            document.getElementById("speakingStatus").style.color = "red";
                        }
                    }
                };

                ws.onclose = function() {
                    if (audioContext && audioContext.state !== 'closed') {
                        audioContext.close();
                    }
                    document.getElementById("status").innerText = "Closed";
                };
            }

            function switchModel() {
                const model = document.getElementById("modelSelect").value;
                if (ws && ws.readyState === WebSocket.OPEN) {
                    if (model === "tiny") {
                        ws.send("switch_to_tiny");
                    } else if (model === "base") {
                        ws.send("switch_to_base");
                    }
                }
            }

            // Convert a Float32Array in [-1, 1] to little-endian 16-bit PCM bytes.
            function floatTo16BitPCM(input) {
                const buffer = new ArrayBuffer(input.length * 2);
                const output = new DataView(buffer);
                for (let i = 0; i < input.length; i++) {
                    let s = Math.max(-1, Math.min(1, input[i]));
                    output.setInt16(i * 2, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
                }
                return buffer;
            }
        </script>
    </body>
    </html>
    """


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
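# To run the demo locally (assumed environment): install fastapi, uvicorn,
# silero-vad, and the Moonshine ONNX package, start this script, then open
# http://localhost:7860 in a browser and grant microphone access.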