# MMS-proxyapi / app.py
# Author: FredyHoundayi
# Commit da466fe: Add LID endpoint using facebook/mms-lid-256
import io
import torch
import torch.nn.functional as F
import librosa
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from transformers import Wav2Vec2ForCTC, AutoProcessor, AutoFeatureExtractor, AutoModelForAudioClassification
app = FastAPI(title="MMS Speech-to-Text API", version="2.0.0")
MODEL_ID = "facebook/mms-1b-all"
LID_MODEL_ID = "facebook/mms-lid-256"
processor = None
model = None
lid_extractor = None
lid_model = None
@app.on_event("startup")
async def load_model():
    """Download and initialize the ASR and LID models once at startup.

    Populates the module-level singletons (`processor`, `model`,
    `lid_extractor`, `lid_model`) that the endpoints check against None.
    Both models are switched to eval mode since the service only runs
    inference.
    """
    global processor, model, lid_extractor, lid_model

    print("Loading MMS ASR model...")
    asr_processor = AutoProcessor.from_pretrained(MODEL_ID)
    asr_model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
    asr_model.eval()
    processor, model = asr_processor, asr_model

    print("Loading MMS LID model...")
    extractor = AutoFeatureExtractor.from_pretrained(LID_MODEL_ID)
    classifier = AutoModelForAudioClassification.from_pretrained(LID_MODEL_ID)
    classifier.eval()
    lid_extractor, lid_model = extractor, classifier

    print("All models loaded.")
@app.get("/")
def root():
    """Service banner: reports the API name and the ASR checkpoint in use."""
    payload = {"message": "MMS Speech-to-Text API"}
    payload["model"] = MODEL_ID
    return payload
@app.get("/health")
def health():
    """Liveness/readiness probe: flags whether each model finished loading."""
    asr_ready = model is not None
    lid_ready = lid_model is not None
    return {"status": "ok", "asr_model_loaded": asr_ready, "lid_model_loaded": lid_ready}
def _group_words(tokens, frame_confidences):
    """Collapse per-frame CTC tokens into words with averaged confidences.

    Applies the standard CTC decoding order: collapse consecutive repeated
    frames FIRST (blanks included), then drop the blank ("<pad>") frames.
    The previous implementation skipped pads before the repeat check and
    never updated the previous-token tracker on pads, so a doubled letter
    separated by a blank (e.g. "l <pad> l" in "hello") was wrongly merged
    into a single letter — diverging from `processor.decode`, which uses
    the correct order.

    tokens: list[str] of per-frame tokens ("|" is the word delimiter).
    frame_confidences: per-frame max softmax probabilities (same length).
    Returns a list of {"word": str, "confidence": float} dicts.
    """
    words = []
    current_word = ""
    current_confs = []
    prev_token = None
    for tok, conf in zip(tokens, frame_confidences):
        # Collapse repeats first; prev_token is updated for EVERY frame,
        # including blanks, so "a <pad> a" correctly yields "aa".
        if tok == prev_token:
            continue
        prev_token = tok
        if tok == "<pad>":  # CTC blank
            continue
        if tok == "|":  # wav2vec2 word delimiter
            if current_word:
                words.append({
                    "word": current_word,
                    "confidence": float(sum(current_confs) / len(current_confs)),
                })
                current_word = ""
                current_confs = []
        else:
            current_word += tok
            current_confs.append(float(conf))
    if current_word:  # flush trailing word with no "|" after it
        words.append({
            "word": current_word,
            "confidence": float(sum(current_confs) / len(current_confs)),
        })
    return words


@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with the MMS ASR model.

    Returns JSON with the transcription, a global confidence score (mean of
    per-frame max probabilities), an uncertainty score (mean per-frame
    entropy of the softmax distribution), and per-word confidences.

    Raises 503 if the model is not loaded yet, 400 if the audio cannot be
    decoded.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    audio_bytes = await file.read()
    try:
        # Resample to 16 kHz mono — the rate the MMS checkpoints expect.
        audio, sampling_rate = librosa.load(io.BytesIO(audio_bytes), sr=16000, mono=True)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load audio: {e}") from e
    inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probs = F.softmax(logits, dim=-1)
    predicted_ids = torch.argmax(probs, dim=-1)[0]
    token_probs = torch.max(probs, dim=-1).values[0]
    transcription = processor.decode(predicted_ids)
    tokens = processor.tokenizer.convert_ids_to_tokens(predicted_ids)
    words = _group_words(tokens, token_probs)
    # NOTE(review): both aggregate scores average over ALL frames, blanks
    # included — presumably intentional; confirm whether blanks should be
    # excluded to match the per-word confidences.
    global_conf = float(token_probs.mean().item())
    # Entropy of the per-frame distribution; epsilon guards log(0).
    entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
    uncertainty = float(entropy.mean().item())
    return JSONResponse({
        "transcription": transcription,
        "confidence": global_conf,
        "uncertainty": uncertainty,
        "words": words,
    })
@app.post("/lid")
async def language_identification(file: UploadFile = File(...)):
    """Identify the spoken language of an uploaded audio file.

    Runs the MMS LID classifier on the (16 kHz mono, resampled) audio and
    returns JSON with the predicted language label.

    Raises 503 if the LID model is not loaded yet, 400 if the audio cannot
    be decoded.
    """
    if lid_model is None or lid_extractor is None:
        raise HTTPException(status_code=503, detail="LID model not loaded yet")
    raw = await file.read()
    try:
        waveform, _ = librosa.load(io.BytesIO(raw), sr=16000, mono=True)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load audio: {e}")
    features = lid_extractor(waveform, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        prediction = lid_model(**features)
    best = prediction.logits.argmax(dim=-1).item()
    return JSONResponse({"language": lid_model.config.id2label[best]})