# pronunce-api / app.py
from fastapi import FastAPI, File, HTTPException, UploadFile
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch
import io
import soundfile as sf
import os
from pydub import AudioSegment
# --- FINAL FIX: Use the writable /tmp directory for the cache ---
# The /code directory is read-only in Hugging Face Spaces. /tmp is writable.
CACHE_DIR = "/tmp/huggingface-cache"
os.makedirs(CACHE_DIR, exist_ok=True)
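# Optionally (a sketch, not something the original app relies on): pointing HF_HOME at
# the same writable location makes other Hugging Face tooling use it as well.
# os.environ.setdefault("HF_HOME", CACHE_DIR)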

# Initialize the FastAPI app
app = FastAPI()

# --- FIX: Load the model and processor using the writable cache_dir ---
model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft"
processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=CACHE_DIR)
model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=CACHE_DIR)

# Ensure the model is in evaluation mode
model.eval()

# Function to convert audio to the format the model expects (16 kHz, mono WAV)
def convert_audio(audio_bytes):
    try:
        # Load audio from bytes using pydub (ffmpeg handles decoding)
        audio = AudioSegment.from_file(io.BytesIO(audio_bytes))
        # Downmix to mono
        audio = audio.set_channels(1)
        # Resample to 16 kHz
        audio = audio.set_frame_rate(16000)
        # Export to an in-memory buffer in WAV format
        buffer = io.BytesIO()
        audio.export(buffer, format="wav")
        buffer.seek(0)
        return buffer.read()
    except Exception as e:
        # This catches errors if ffmpeg has issues with a specific file
        raise ValueError(f"Error processing audio file: {e}")
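
# A quick local sanity check for convert_audio (a sketch for manual testing only;
# "sample.webm" is a hypothetical file, not part of this Space):
#
#     with open("sample.webm", "rb") as f:
#         wav_bytes = convert_audio(f.read())
#     data, sr = sf.read(io.BytesIO(wav_bytes))
#     assert sr == 16000 and data.ndim == 1  # mono audio resampled to 16 kHz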


@app.post("/assess-pronunciation/")
async def assess_pronunciation(audio_file: UploadFile = File(...)):
    """
    Takes an audio file, converts it to 16 kHz mono WAV, and returns the
    recognized phoneme sequence.
    """
    # Read the uploaded file's raw bytes
    audio_bytes = await audio_file.read()

    # Convert the audio to the model's required format (16 kHz, mono WAV)
    try:
        processed_audio_bytes = convert_audio(audio_bytes)
    except ValueError as e:
        # Return a proper 400 response instead of a 200 with an "error" key
        raise HTTPException(status_code=400, detail=str(e))

    # Load the waveform from the processed audio bytes
    waveform, sample_rate = sf.read(io.BytesIO(processed_audio_bytes), dtype="float32")

    # Turn the waveform into model inputs
    input_values = processor(
        waveform, sampling_rate=sample_rate, return_tensors="pt", padding="longest"
    ).input_values

    # Run inference without tracking gradients
    with torch.no_grad():
        logits = model(input_values).logits

    # Pick the most likely token at each frame and decode into phonemes
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)

    # batch_decode returns a list with one item, so return that item
    return {"phoneme_transcription": transcription[0]}


@app.get("/")
def read_root():
    return {"message": "Wav2Vec2 Pronunciation Assessment API is running."}