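"""
FastAPI speech-to-text demo: browser audio is streamed over a WebSocket,
segmented with Silero VAD, and transcribed with Moonshine ONNX
(switchable between the "tiny" and "base" models).
"""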
import time
import asyncio
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
# Speech-to-text model and VAD libraries.
from silero_vad import VADIterator, load_silero_vad
from moonshine_onnx import MoonshineOnnxModel, load_tokenizer
# Constants
SAMPLING_RATE = 16000
CHUNK_SIZE = 512 # Required for Silero VAD at 16kHz.
LOOKBACK_CHUNKS = 5
MAX_SPEECH_SECS = 15 # Maximum duration for a single transcription segment.
MIN_REFRESH_SECS = 1/2 # Minimum interval for sending partial updates.
app = FastAPI()
class Transcriber:
    def __init__(self, model_name: str, rate: int = 16000):
        if rate != 16000:
            raise ValueError("Moonshine supports a sampling rate of 16000 Hz only.")
        # Pass model_precision="quantized" for the quantized model; the default is float.
        self.model = MoonshineOnnxModel(model_name=model_name)
        self.rate = rate
        self.tokenizer = load_tokenizer()
        # Statistics (optional).
        self.inference_secs = 0
        self.number_inferences = 0
        self.speech_secs = 0
        # Warmup run.
        self.__call__(np.zeros(int(rate), dtype=np.float32))

    def __call__(self, speech: np.ndarray) -> str:
        """Returns a transcription of the given speech (a float32 numpy array)."""
        self.number_inferences += 1
        self.speech_secs += len(speech) / self.rate
        start_time = time.time()
        tokens = self.model.generate(speech[np.newaxis, :].astype(np.float32))
        text = self.tokenizer.decode_batch(tokens)[0]
        self.inference_secs += time.time() - start_time
        return text
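# Example usage (mirrors the warmup call above):
#   transcriber = Transcriber("moonshine/tiny")
#   text = transcriber(np.zeros(SAMPLING_RATE, dtype=np.float32))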
def pcm16_to_float32(pcm_data: bytes) -> np.ndarray:
    """
    Convert 16-bit PCM bytes into a float32 numpy array with values in [-1, 1].
    """
    int_data = np.frombuffer(pcm_data, dtype=np.int16)
    float_data = int_data.astype(np.float32) / 32768.0
    return float_data
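# For example, the little-endian int16 bytes b"\x00\x80" (-32768) map to -1.0,
# and b"\xff\x7f" (32767) map to ~0.99997.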
# Initialize models.
model_name_tiny = "moonshine/tiny"
model_name_base = "moonshine/base"
transcriber_tiny = Transcriber(model_name=model_name_tiny, rate=SAMPLING_RATE)
transcriber_base = Transcriber(model_name=model_name_base, rate=SAMPLING_RATE)
vad_model = load_silero_vad(onnx=True)
vad_iterator = VADIterator(
    model=vad_model,
    sampling_rate=SAMPLING_RATE,
    threshold=0.5,
    min_silence_duration_ms=300,
)
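# WebSocket protocol: the client sends text frames "switch_to_tiny" / "switch_to_base"
# to select the model, and binary frames of 16-bit PCM audio at 16 kHz. The server
# replies with JSON messages of type "status", "partial", or "final".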
@app.websocket("/ws/transcribe")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    caption_cache = []
    lookback_size = LOOKBACK_CHUNKS * CHUNK_SIZE
    speech = np.empty(0, dtype=np.float32)
    recording = False
    last_partial_time = time.time()
    current_model = transcriber_tiny  # Default to the tiny model.
    last_output = ""
    try:
        while True:
            data = await websocket.receive()
            if data["type"] == "websocket.receive":
                # Text frames switch the active model; binary frames carry audio.
                if data.get("text") == "switch_to_tiny":
                    current_model = transcriber_tiny
                    continue
                elif data.get("text") == "switch_to_base":
                    current_model = transcriber_base
                    continue
                chunk = pcm16_to_float32(data["bytes"])
                speech = np.concatenate((speech, chunk))
                if not recording:
                    # Keep only a short lookback buffer until speech is detected.
                    speech = speech[-lookback_size:]
                vad_result = vad_iterator(chunk)
                current_time = time.time()
                if vad_result:
                    if "start" in vad_result and not recording:
                        recording = True
                        await websocket.send_json({"type": "status", "message": "speaking_started"})
                    if "end" in vad_result and recording:
                        recording = False
                        text = current_model(speech)
                        await websocket.send_json({"type": "final", "transcript": text})
                        caption_cache.append(text)
                        speech = np.empty(0, dtype=np.float32)
                        # Soft-reset the VAD iterator for the next utterance.
                        vad_iterator.triggered = False
                        vad_iterator.temp_end = 0
                        vad_iterator.current_sample = 0
                        await websocket.send_json({"type": "status", "message": "speaking_stopped"})
                elif recording:
                    if (len(speech) / SAMPLING_RATE) > MAX_SPEECH_SECS:
                        # Force a final transcript if the segment grows too long.
                        recording = False
                        text = current_model(speech)
                        await websocket.send_json({"type": "final", "transcript": text})
                        caption_cache.append(text)
                        speech = speech[-10:]
                    elif (current_time - last_partial_time) > MIN_REFRESH_SECS + 0.2 * (len(speech) / SAMPLING_RATE):
                        # Send a partial transcript, throttled more as the segment grows.
                        text = current_model(speech)
                        if last_output != text:
                            last_output = text
                            await websocket.send_json({"type": "partial", "transcript": text})
                        last_partial_time = current_time
    except WebSocketDisconnect:
        if recording and speech.size:
            text = current_model(speech)
            await websocket.send_json({"type": "final", "transcript": text})
        print("WebSocket disconnected")
@app.get("/", response_class=HTMLResponse)
async def get_home():
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="UTF-8">
        <title>Moonshine Realtime Transcription</title>
        <link href="https://cdn.jsdelivr.net/npm/tailwindcss@2.2.19/dist/tailwind.min.css" rel="stylesheet">
    </head>
    <body class="bg-gray-100 p-6">
        <div class="max-w-3xl mx-auto bg-white p-6 rounded-lg shadow-md">
            <h1 class="text-2xl font-bold mb-4">Realtime Transcription</h1>
            <button onclick="startTranscription()" class="bg-blue-500 text-white px-4 py-2 rounded mb-4">Start Transcription</button>
            <select id="modelSelect" onchange="switchModel()" class="bg-gray-200 px-4 py-2 rounded mb-4">
                <option value="tiny">Tiny Model</option>
                <option value="base">Base Model</option>
            </select>
            <p id="status" class="text-gray-600 mb-4">Click start to begin transcription.</p>
            <p id="speakingStatus" class="text-gray-600 mb-4"></p>
            <div id="transcription" class="border p-4 rounded mb-4 h-64 overflow-auto"></div>
            <div id="visualizer" class="border p-4 rounded h-64">
                <canvas id="audioCanvas" class="w-full h-full"></canvas>
            </div>
        </div>
        <script>
            let ws;
            let audioContext;
            let scriptProcessor;
            let mediaStream;
            let currentLine = document.createElement('span');
            let analyser;
            let canvas, canvasContext;
            document.getElementById('transcription').appendChild(currentLine);
            canvas = document.getElementById('audioCanvas');
            canvasContext = canvas.getContext('2d');
            async function startTranscription() {
                document.getElementById("status").innerText = "Connecting...";
                ws = new WebSocket("wss://" + location.host + "/ws/transcribe");
                ws.binaryType = 'arraybuffer';
                ws.onopen = async function() {
                    document.getElementById("status").innerText = "Connected";
                    try {
                        mediaStream = await navigator.mediaDevices.getUserMedia({ audio: true });
                        // Capture at 16 kHz to match the server-side models.
                        audioContext = new AudioContext({ sampleRate: 16000 });
                        const source = audioContext.createMediaStreamSource(mediaStream);
                        analyser = audioContext.createAnalyser();
                        analyser.fftSize = 2048;
                        const bufferLength = analyser.frequencyBinCount;
                        const dataArray = new Uint8Array(bufferLength);
                        source.connect(analyser);
                        // 512-sample frames match the server's CHUNK_SIZE for Silero VAD.
                        scriptProcessor = audioContext.createScriptProcessor(512, 1, 1);
                        scriptProcessor.onaudioprocess = function(event) {
                            const inputData = event.inputBuffer.getChannelData(0);
                            const pcm16 = floatTo16BitPCM(inputData);
                            if (ws.readyState === WebSocket.OPEN) {
                                ws.send(pcm16);
                            }
                            // Draw the waveform visualizer.
                            analyser.getByteTimeDomainData(dataArray);
                            canvasContext.fillStyle = 'rgb(200, 200, 200)';
                            canvasContext.fillRect(0, 0, canvas.width, canvas.height);
                            canvasContext.lineWidth = 2;
                            canvasContext.strokeStyle = 'rgb(0, 0, 0)';
                            canvasContext.beginPath();
                            let sliceWidth = canvas.width * 1.0 / bufferLength;
                            let x = 0;
                            for (let i = 0; i < bufferLength; i++) {
                                let v = dataArray[i] / 128.0;
                                let y = v * canvas.height / 2;
                                if (i === 0) {
                                    canvasContext.moveTo(x, y);
                                } else {
                                    canvasContext.lineTo(x, y);
                                }
                                x += sliceWidth;
                            }
                            canvasContext.lineTo(canvas.width, canvas.height / 2);
                            canvasContext.stroke();
                        };
                        source.connect(scriptProcessor);
                        scriptProcessor.connect(audioContext.destination);
                    } catch (err) {
                        document.getElementById("status").innerText = "Error: " + err;
                    }
                };
                ws.onmessage = function(event) {
                    const data = JSON.parse(event.data);
                    if (data.type === 'partial') {
                        currentLine.style.color = 'gray';
                        currentLine.textContent = data.transcript + ' ';
                    } else if (data.type === 'final') {
                        currentLine.style.color = 'black';
                        currentLine.textContent = data.transcript;
                        currentLine = document.createElement('span');
                        document.getElementById('transcription').appendChild(document.createElement('br'));
                        document.getElementById('transcription').appendChild(currentLine);
                    } else if (data.type === 'status') {
                        if (data.message === 'speaking_started') {
                            document.getElementById("speakingStatus").innerText = "Speaking Started";
                            document.getElementById("speakingStatus").style.color = "green";
                        } else if (data.message === 'speaking_stopped') {
                            document.getElementById("speakingStatus").innerText = "Speaking Stopped";
                            document.getElementById("speakingStatus").style.color = "red";
                        }
                    }
                };
                ws.onclose = function() {
                    if (audioContext && audioContext.state !== 'closed') {
                        audioContext.close();
                    }
                    document.getElementById("status").innerText = "Closed";
                };
            }
            function switchModel() {
                const model = document.getElementById("modelSelect").value;
                if (ws && ws.readyState === WebSocket.OPEN) {
                    if (model === "tiny") {
                        ws.send("switch_to_tiny");
                    } else if (model === "base") {
                        ws.send("switch_to_base");
                    }
                }
            }
            function floatTo16BitPCM(input) {
                // Clamp each sample to [-1, 1] and scale to the signed 16-bit range (little-endian).
                const buffer = new ArrayBuffer(input.length * 2);
                const output = new DataView(buffer);
                for (let i = 0; i < input.length; i++) {
                    let s = Math.max(-1, Math.min(1, input[i]));
                    output.setInt16(i * 2, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
                }
                return buffer;
            }
        </script>
    </body>
    </html>
    """
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)