Spaces: Runtime error
Commit · 83f43dd
1 Parent(s): 7bf7549
refactor: modularize all components
- Dockerfile +1 -1
- app/__pycache__/app.cpython-39.pyc +0 -0
- app/__pycache__/main.cpython-39.pyc +0 -0
- app/main.py +49 -0
- app/models/__init__.py +0 -0
- app/models/__pycache__/__init__.cpython-39.pyc +0 -0
- app/models/__pycache__/ssl_singleton.cpython-39.pyc +0 -0
- app/models/__pycache__/transcriber_singleton.cpython-39.pyc +0 -0
- app/models/ssl_singleton.py +48 -0
- app/models/transcriber_singleton.py +45 -0
- app/modules/__init__.py +0 -0
- app/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- app/modules/pronunciation_coach/__init__.py +0 -0
- app/modules/pronunciation_coach/__pycache__/__init__.cpython-39.pyc +0 -0
- app/modules/pronunciation_coach/__pycache__/pronunciation_assessor.cpython-39.pyc +0 -0
- app/modules/pronunciation_coach/__pycache__/pronunciation_assessor_utils.cpython-39.pyc +0 -0
- app.py → app/modules/pronunciation_coach/pronunciation_assessor.py +3 -395
- app/modules/pronunciation_coach/pronunciation_assessor_utils.py +73 -0
- app/routes/__init__.py +0 -0
- app/routes/__pycache__/__init__.cpython-39.pyc +0 -0
- app/routes/__pycache__/predict.cpython-39.pyc +0 -0
- app/routes/__pycache__/transcribe.cpython-39.pyc +0 -0
- app/routes/predict.py +58 -0
- app/routes/transcribe.py +61 -0
- app/services/__init__.py +0 -0
- app/services/__pycache__/__init__.cpython-39.pyc +0 -0
- app/services/__pycache__/evaluate_pronunciation.cpython-39.pyc +0 -0
- app/services/__pycache__/transcribe.cpython-39.pyc +0 -0
- app/services/evaluate_pronunciation.py +69 -0
- app/services/transcribe.py +56 -0
- notebook-inference.ipynb → app/tester-notebook.ipynb +0 -0
- app/utils/__init__.py +0 -0
- app/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- app/utils/__pycache__/cache.cpython-39.pyc +0 -0
- app/utils/__pycache__/general_utils.cpython-39.pyc +0 -0
- app/utils/cache.py +48 -0
- app/utils/general_utils.py +73 -0
- inference.py +0 -214
Dockerfile
CHANGED
@@ -29,4 +29,4 @@ COPY . .
 EXPOSE 7860
 
 # Run the FastAPI application
-CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
+CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
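Side note (not part of the commit): the module path changes because app.py moved into the app/ package, so uvicorn now has to import app.main. A minimal local sanity check, assuming the package layout added in this commit:

# Hedged sketch: run the relocated app locally, mirroring the new Dockerfile CMD.
# Assumes the app/ package from this commit is importable where this is run.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("app.main:app", host="0.0.0.0", port=7860, log_level="info")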
app/__pycache__/app.cpython-39.pyc
ADDED
Binary file (1.71 kB).

app/__pycache__/main.cpython-39.pyc
ADDED
Binary file (1.8 kB).

app/main.py
ADDED
@@ -0,0 +1,49 @@
+
+from fastapi import FastAPI, UploadFile, Form, HTTPException
+from fastapi.responses import JSONResponse
+import uvicorn
+from typing import List
+import torch
+import soundfile as sf
+from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
+import re
+import numpy as np
+import cmudict
+from io import BytesIO
+import os
+import logging
+from joblib import Memory
+from difflib import SequenceMatcher
+import eng_to_ipa as ipa_conv
+import os
+import copy
+from IPython.display import HTML, display
+from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+from pydub import AudioSegment
+from Bio import pairwise2
+from Bio.pairwise2 import format_alignment
+import asyncio
+from cachetools import TTLCache
+
+# Set the Numba cache directory to a writable location
+os.environ["NUMBA_CACHE_DIR"] = "/tmp"
+import librosa
+logging.basicConfig(level=logging.INFO)
+
+# package imports
+from routes.transcribe import router as transcriber_router
+from routes.predict import router as pronunciation_evaluation_router
+# Initialize FastAPI app
+app = FastAPI(title="Talkiee AI", version="1.0.0")
+
+# health check
+@app.get("/")
+def home():
+    return "Healthy bro!"
+
+app.include_router(transcriber_router, tags=["transcribe"])
+app.include_router(pronunciation_evaluation_router, tags=["pronunciation_evaluation"])
+if __name__ == '__main__':
+    port = os.environ.get("PORT", 10000)  # Default to 10000 if PORT is not set
+    logging.info(f"Starting server on PORT {port}")
+    uvicorn.run("main:app", host="0.0.0.0", port=int(port), log_level="info")
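For orientation (not in the diff): the new entrypoint can be smoke-tested with FastAPI's test client. This assumes the module's bare imports (routes..., utils...) resolve, i.e. that app/ itself is on sys.path, as when the server is started from inside that directory.

# Hypothetical smoke test for app/main.py; run from within app/ so that
# "from routes.transcribe import ..." and friends can be imported.
from fastapi.testclient import TestClient
from main import app

client = TestClient(app)
response = client.get("/")            # the health-check route added above
assert response.status_code == 200
assert response.json() == "Healthy bro!"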
app/models/__init__.py
ADDED
File without changes

app/models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (171 Bytes).

app/models/__pycache__/ssl_singleton.cpython-39.pyc
ADDED
Binary file (2.04 kB).

app/models/__pycache__/transcriber_singleton.cpython-39.pyc
ADDED
Binary file (2.04 kB).

app/models/ssl_singleton.py
ADDED
@@ -0,0 +1,48 @@
+import os
+import torch
+from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
+from utils.general_utils import process_audio
+import asyncio
+import librosa
+from utils.cache import audio_cache
+
+class SSLSingleton:
+    _instance = None
+
+    def __new__(cls, model_name="mrrubino/wav2vec2-large-xlsr-53-l2-arctic-phoneme", device=None):
+        if cls._instance is None:
+            cls._instance = super(SSLSingleton, cls).__new__(cls)
+            cls._instance._initialize(model_name, device)
+        return cls._instance
+
+    def _initialize(self, model_name, device):
+        # Set device (CPU or GPU)
+        # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = "cpu"
+        # Load processor and model
+        print("Loading SSL processor and model...")  # This will only happen once
+        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
+        self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
+        self.model.eval()
+        self.model.to(self.device)  # Move model to the specified device
+
+    # an infernce function taking in processed audio input and returning the predictions
+    def infer(self, audio_input, device):
+        inputs = self.processor(audio_input, sampling_rate=16000, return_tensors="pt")
+        inputs = inputs.to(self.device)
+
+        with torch.no_grad():
+            logits = self.model(inputs.input_values).logits
+
+        predicted_ids = torch.argmax(logits, dim=-1)
+        uttered_phonemes = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+        return uttered_phonemes
+
+    async def infer_and_save_to_cache(self, file_name, audio_input, device):
+        uttered_phonemes = self.infer(audio_input, device)
+        async with audio_cache.lock:
+            new_cache = audio_cache.cache[file_name]
+            new_cache["uttered_phonemes"] = uttered_phonemes
+            audio_cache.cache[file_name] = new_cache
+        return uttered_phonemes
+ssl_model = SSLSingleton()
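For readers new to the pattern: the __new__-based singleton above guarantees the heavyweight processor/model pair is loaded once per process and reused by every caller. A stripped-down illustration with generic names (not from the repo):

# Minimal singleton sketch mirroring SSLSingleton's structure (illustrative only).
class HeavyModelSingleton:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.model = "loaded once"  # stands in for from_pretrained(...)
        return cls._instance

a = HeavyModelSingleton()
b = HeavyModelSingleton()
assert a is b                 # every "constructor" call returns the same object
assert a.model == b.model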
app/models/transcriber_singleton.py
ADDED
@@ -0,0 +1,45 @@
+import os
+import torch
+from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+from utils.general_utils import process_audio
+import asyncio
+import librosa
+
+class TranscriberSingleton:
+    _instance = None
+
+    def __new__(cls, model_name="openai/whisper-tiny.en", device=None):
+        if cls._instance is None:
+            cls._instance = super(TranscriberSingleton, cls).__new__(cls)
+            cls._instance._initialize(model_name, device)
+        return cls._instance
+
+    def _initialize(self, model_name, device):
+        # Set device (CPU or GPU)
+        # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = "cpu"
+        # Load processor and model
+        print(f"Loading Whisper processor and model into {device}...")  # This will only happen once
+        self.processor = AutoProcessor.from_pretrained(model_name)
+        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name)
+        self.model.eval()
+        self.model.to(self.device)  # Move model to the specified device
+
+    def transcribe_into_English(self, audio_input):
+        # Load audio file
+        audio_input = self.processor(audio_input, sampling_rate=16000, return_tensors="pt", language="en").to(self.device)
+
+        # Perform transcription
+        with torch.no_grad():
+            generated_ids = self.model.generate(audio_input.input_features)
+
+        # Decode the transcription
+        transcription = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        return transcription.lower().strip()
+
+    def transcribe_from_file_path(self, file_path, target_sr=16000):
+        with open(file_path, "rb") as f:
+            audio_input, sr = librosa.load(file_path, sr=target_sr)
+        return self.transcribe_into_English(audio_input)
+
+transcriber_model = TranscriberSingleton()
app/modules/__init__.py
ADDED
File without changes

app/modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (172 Bytes).

app/modules/pronunciation_coach/__init__.py
ADDED
File without changes

app/modules/pronunciation_coach/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (192 Bytes).

app/modules/pronunciation_coach/__pycache__/pronunciation_assessor.cpython-39.pyc
ADDED
Binary file (22.6 kB).

app/modules/pronunciation_coach/__pycache__/pronunciation_assessor_utils.cpython-39.pyc
ADDED
Binary file (2.16 kB).
app.py → app/modules/pronunciation_coach/pronunciation_assessor.py
RENAMED
@@ -1,7 +1,3 @@
-
-from fastapi import FastAPI, UploadFile, Form, HTTPException
-from fastapi.responses import JSONResponse
-import uvicorn
 from typing import List
 import torch
 import soundfile as sf
@@ -10,7 +6,6 @@ import re
 import numpy as np
 import cmudict
 from io import BytesIO
-import os
 import logging
 from joblib import Memory
 from difflib import SequenceMatcher
@@ -24,267 +19,10 @@ from Bio import pairwise2
 from Bio.pairwise2 import format_alignment
 import asyncio
 from cachetools import TTLCache
-
-# Set the Numba cache directory to a writable location
-os.environ["NUMBA_CACHE_DIR"] = "/tmp"
-import librosa
-
-logging.basicConfig(level=logging.INFO)
-
-cmu = cmudict.dict()
-
-# Initialize FastAPI app
-app = FastAPI()
-
-# Load the processor and model
-MODEL_NAME = "mrrubino/wav2vec2-large-xlsr-53-l2-arctic-phoneme" # wav2vec based phoneme trascriber trained on L2-ARTIC
-processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
-model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
-model.eval()
-
-# Check device availability
-# device = "cuda" if torch.cuda.is_available() else "cpu"
-device = 'cpu' # TEMP for testing
-model.to(device)
-
-whisper_processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
-whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny.en")
-whisper_model.eval()
-whisper_model.to(device)
-
-# =====================================
-# Section: Utils
-# =====================================
-
-# Initialize a cache with a 5-minute TTL and 100 items max
-audio_cache = TTLCache(maxsize=100, ttl=300)
-cache_lock = asyncio.Lock() # To prevent race conditions
-
-import os
-from tempfile import NamedTemporaryFile
-import subprocess
-
-async def process_audio(audio, device):
-    """
-    Process an uploaded audio file and prepare input for the model.
-    Converts audio to WAV format using FFmpeg prior to processing.
-
-    Args:
-        audio: The uploaded audio file.
-        device: The device (e.g., 'cuda' or 'cpu') to move tensors to.
-
-    Returns:
-        cache_entry: A dictionary containing processed audio and model input.
-    """
-    filename = audio.filename
-
-    # Check cache for processed audio
-    if filename in audio_cache:
-        logging.info(f"Audio '{filename}' found in cache.")
-        return audio_cache[filename]
-
-    async with cache_lock: # Prevent race conditions during cache writes
-        if filename in audio_cache: # Double-check after acquiring lock
-            logging.info(f"Audio '{filename}' found in cache after lock.")
-            return audio_cache[filename]
-
-        logging.info(f"Processing audio '{filename}'.")
-
-        # Read the audio file into a temporary file
-        with NamedTemporaryFile(delete=False, suffix=".m4a") as temp_m4a:
-            temp_m4a_path = temp_m4a.name
-            temp_m4a.write(await audio.read())
-
-        # Convert M4A to WAV using FFmpeg
-        temp_wav_path = temp_m4a_path.replace(".m4a", ".wav")
-        try:
-            subprocess.run(
-                [
-                    "ffmpeg", "-i", temp_m4a_path, # Input file
-                    "-ar", "16000", # Resample to 16kHz
-                    "-ac", "1", # Convert to mono
-                    temp_wav_path # Output file
-                ],
-                check=True,
-                stdout=subprocess.PIPE,
-                stderr=subprocess.PIPE
-            )
-        except subprocess.CalledProcessError as e:
-            logging.error(f"FFmpeg conversion failed: {e.stderr.decode()}")
-            raise HTTPException(status_code=500, detail="Failed to process audio file.")
-        finally:
-            os.remove(temp_m4a_path) # Clean up the temporary M4A file
-
-        try:
-            # Load the WAV audio for further processing
-            audio_segment = AudioSegment.from_file(temp_wav_path, format="wav")
-            audio_samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float32)
-            max_val = np.iinfo(np.int16).max
-            audio_samples /= max_val
-
-            if audio_segment.channels > 1:
-                audio_samples = audio_samples.reshape(-1, audio_segment.channels).mean(axis=1)
-
-            audio_input = librosa.resample(audio_samples, orig_sr=audio_segment.frame_rate, target_sr=16000)
-            input_values = processor(audio_input, return_tensors="pt", sampling_rate=16000).input_values.to(device)
-
-            # Cache the processed audio
-            cache_entry = {"audio_input": audio_input, "input_values": input_values, "ssl_logits": None}
-            audio_cache[filename] = cache_entry
-            return cache_entry
-
-        finally:
-            # Clean up the temporary WAV file
-            os.remove(temp_wav_path)
-
-
-async def run_ssl_inference(filename, input_values):
-    """
-    Run SSL model inference in the background and store the results in the cache.
-
-    Args:
-        filename: The name of the audio file.
-        input_values: The processed input tensor for the SSL model.
-    """
-    try:
-        logging.info(f"Running SSL inference for '{filename}' in the background.")
-        with torch.no_grad():
-            ssl_output = model(input_values).logits
-
-        # Update the cache with the SSL inference result
-        if filename in audio_cache:
-            audio_cache[filename]["ssl_logits"] = ssl_output
-            logging.info(f"SSL inference for '{filename}' completed and cached.")
-    except Exception as e:
-        logging.error(f"Error during SSL inference for '{filename}': {e}")
-
-def transcribe_into_English(audio_input):
-    # Load audio file
-    audio_input = whisper_processor(audio_input, sampling_rate=16000, return_tensors="pt", language="en").to(device)
-
-    # Perform transcription
-    with torch.no_grad():
-        generated_ids = whisper_model.generate(audio_input.input_features)
-
-    # Decode the transcription
-    transcription = whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    return transcription.lower().strip()
-
-def get_nested_position(nested_list, flat_index):
-    """
-    Finds the nested list and the index within it for a given flat index.
-
-    Args:
-        nested_list (list of lists): The list of lists.
-        flat_index (int): The flattened index.
-
-    Returns:
-        tuple: (nested_list_index, element_index_in_nested_list)
-    """
-    cumulative_index = 0
-
-    for list_index, sublist in enumerate(nested_list):
-        # Check if the flat index falls within the current sublist
-        if cumulative_index + len(sublist) > flat_index:
-            # Calculate the index within the sublist
-            element_index = flat_index - cumulative_index
-            return list_index, element_index
-        # Update cumulative index
-        cumulative_index += len(sublist)
-
-    raise IndexError("Index out of range for the flattened list.")
-
-def label_specific_elements_in_reference(reference, start_word_idx, start_element_idx, end_word_idx, end_element_idx, label):
-    """
-    Labels elements in a nested list between specified start and end indices (inclusive).
-
-    Args:
-        reference (list of lists): The original list of lists.
-        start_word_idx (int): Index of the starting nested list.
-        start_element_idx (int): Index of the starting element in the start list.
-        end_word_idx (int): Index of the ending nested list.
-        end_element_idx (int): Index of the ending element in the end list.
-        label: The label to attach to the elements.
-
-    Returns:
-        list of lists: A new list of lists with labels attached where applicable.
-    """
-    labeled_reference = []
-    for word_idx, sublist in enumerate(reference):
-        labeled_sublist = []
-
-        for element_idx, element in enumerate(sublist):
-            if start_word_idx < end_word_idx:
-                # Case 1: start_word_idx < end_word_idx
-                if (
-                    (word_idx > start_word_idx and word_idx < end_word_idx) or
-                    (word_idx == start_word_idx and element_idx >= start_element_idx) or
-                    (word_idx == end_word_idx and element_idx <= end_element_idx)
-                ):
-                    # Attach the label to elements within the inclusive range
-                    if isinstance(element, tuple):
-                        print(f"There is already a label at index ({word_idx}, {element_idx})")
-                    labeled_sublist.append((element, label))
-                else:
-                    # Keep elements outside the range unchanged
-                    labeled_sublist.append(element)
-            elif start_word_idx == end_word_idx:
-                # Case 2: start_word_idx == end_word_idx
-                if word_idx == start_word_idx and start_element_idx <= element_idx <= end_element_idx:
-                    # Attach the label to elements within the inclusive range
-                    if isinstance(element, tuple):
-                        print(f"There is already a label at index ({word_idx}, {element_idx})")
-                    labeled_sublist.append((element, label))
-                else:
-                    # Keep elements outside the range unchanged
-                    labeled_sublist.append(element)
-
-        labeled_reference.append(labeled_sublist)
-
-    return labeled_reference
-
-def clean_text(text: str) -> str:
-    """
-    Remove punctuation from the input string except for special characters
-    that are part of a word, such as ' in I'm or - in hard-working.
-
-    Parameters:
-        text (str): Input string to clean.
-
-    Returns:
-        str: Cleaned string with allowed special characters retained.
-    """
-    # Allow letters, spaces, apostrophes, and hyphens within words
-    cleaned_text = re.sub(r'[^\w\s\'-]', '', text) # Remove punctuation except ' and -
-    cleaned_text = re.sub(r'\s+', ' ', cleaned_text) # Normalize spaces
-    return cleaned_text.lower().strip()
-
-# =====================================
-# Section: IPA Phonemes Utils
-# =====================================
-
-
-# WORKING: converting functions to class, currently done with the last function in the class
-import re
-from difflib import SequenceMatcher
-from IPython.display import HTML, display
-import copy
-from IPython.display import HTML, display
-from Bio import pairwise2
-from Bio.pairwise2 import format_alignment
-
-# WORKING: converting functions to class, currently done with the last function in the class
-import re
-from difflib import SequenceMatcher
-from IPython.display import HTML, display
-import copy
-from IPython.display import HTML, display
-from Bio import pairwise2
-from Bio.pairwise2 import format_alignment
-import cmudict
+from modules.pronunciation_coach.pronunciation_assessor_utils import *
 cmu_dict = cmudict.dict()
 
-class PronunciationAssessment:
+class PronunciationAssessor:
     def __init__(self, transcript, uttered_phonemes):
         # NOTE: removed all long signals ('ː') for compatibility with L2-artic's phoneme set (ssl model training set). American English.
         # ground truth phonemes are converted into arpabet first, and then into ipa using the arpabet_to_ipa dict, meaning the arpabet_to_ipa dict contains
@@ -1159,134 +897,4 @@ class PronunciationAssessment:
 
         # Display
         display(HTML(html_content))
-
-# health check
-@app.get("/")
-def home():
-    return "Healthy bro!"
-
-import time # temp
-
-# taking in both audio and transcript from the user
-@app.post("/predict")
-async def predict(audio: UploadFile, transcript: str = Form(...)):
-    """
-    Predict phoneme labels from uploaded audio and provided transcript.
-
-    Args:
-        audio (UploadFile): Uploaded audio file (WAV/MP3).
-        transcript (str): Ground truth transcript.
-
-    Returns:
-        JSONResponse: Contains phoneme labels.
-    """
-    logging.info("Received prediction request!")
-
-    # Validate file extension
-    allowed_extensions = {"wav", "mp3", "m4a"}
-    filename = audio.filename.lower()
-    start_time = time.time()
-
-    if not filename.endswith(tuple(allowed_extensions)):
-        raise HTTPException(
-            status_code=400,
-            detail="Invalid file type. Only WAV and MP3 files are supported.",
-        )
-
-    # Load and preprocess the audio
-    try:
-        cache_entry = await process_audio(audio, device)
-        input_values = cache_entry["input_values"]
-
-        # Ensure SSL inference is completed
-        logits = cache_entry.get("ssl_logits")
-        if logits is None:
-            logging.info(f"SSL inference not cached for '{filename}', running now.")
-            with torch.no_grad():
-                logits = model(input_values).logits
-            cache_entry["ssl_logits"] = logits
-
-        end_time = time.time()
-        print(f"Time from call to finish processing audio: {end_time - start_time} seconds")
-
-        start_time = time.time()
-        transcript = clean_text(transcript).strip()
-
-        # Decode the phonemes
-        predicted_ids = torch.argmax(logits, dim=-1)
-        uttered_phonemes = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
-        end_time = time.time()
-        print("Time taken for inference:", end_time - start_time)
-
-        start_time = time.time()
-        # init PronunciationAssessment instance
-        cur = PronunciationAssessment(transcript, uttered_phonemes)
-        cur.convert_transcript_into_phonemes()
-        cur.clean_ipa_phonemes()
-        cur.split_phoneme_sequence()
-        print(cur.uttered_ipa_phonemes)
-        # print(cur.segmented_ground_truth_ipa_phonemes)
-        # print(cur.segmented_uttered_ipa_phonemes)
-
-        # generate the final labels
-        labels = cur.generate_labels_for_api()
-        end_time = time.time()
-        print("Time taken for label generation:", end_time - start_time)
-        return JSONResponse(content={"labels": labels})
-
-    except Exception as e:
-        logging.error(f"Error during prediction: {e}")
-        raise HTTPException(status_code=500, detail="An error occurred during processing.")
-
-# taking in audio only and returning the transcript
-@app.post("/transcribe")
-async def transcribe(audio: UploadFile):
-    """
-    Transcribe the uploaded audio and return the transcript.
-
-    Args:
-        audio (UploadFile): Uploaded audio file (WAV/MP3).
-
-    Returns:
-        JSONResponse: Contains the transcript.
-    """
-    logging.info("Received transcription request!")
-
-    # Validate file extension
-    allowed_extensions = {"wav", "mp3", "m4a"}
-    filename = audio.filename.lower()
-    if not filename.endswith(tuple(allowed_extensions)):
-        raise HTTPException(
-            status_code=400,
-            detail="Invalid file type. Only WAV and MP3 files are supported.",
-        )
-
-    # Load and preprocess the audio
-    try:
-        # Process the audio
-        start_time = time.time()
-        cache_entry = await process_audio(audio, device)
-        audio_input = cache_entry["audio_input"]
-        input_values = cache_entry["input_values"]
-
-        # Start SSL inference in the background
-        asyncio.create_task(run_ssl_inference(audio.filename, input_values))
-
-        # Get transcript
-        end_time = time.time()
-        print(f"Time from call to finish processing audio: {end_time - start_time} seconds")
-        transcript = transcribe_into_English(audio_input)
-        transcript = clean_text(transcript).strip()
-        another_end_time = time.time()
-        logging.info(f"Transcript: {transcript}, Time taken from processed audio to finish transcription: {another_end_time - end_time} seconds")
-
-        return JSONResponse(content={"transcript": transcript})
-
-    except Exception as e:
-        logging.error(f"Error during transcription: {e}")
-        raise HTTPException(status_code=500, detail="An error occurred during processing.")
-
-# if __name__ == '__main__':
-#     port = os.environ.get("PORT", 10000) # Default to 10000 if PORT is not set
-#     logging.info(f"Starting server on PORT {port}")
-#     uvicorn.run("app:app", host="0.0.0.0", port=int(port), log_level="info")
+
app/modules/pronunciation_coach/pronunciation_assessor_utils.py
ADDED
@@ -0,0 +1,73 @@
+import numpy as np
+def get_nested_position(nested_list, flat_index):
+    """
+    Finds the nested list and the index within it for a given flat index.
+
+    Args:
+        nested_list (list of lists): The list of lists.
+        flat_index (int): The flattened index.
+
+    Returns:
+        tuple: (nested_list_index, element_index_in_nested_list)
+    """
+    cumulative_index = 0
+
+    for list_index, sublist in enumerate(nested_list):
+        # Check if the flat index falls within the current sublist
+        if cumulative_index + len(sublist) > flat_index:
+            # Calculate the index within the sublist
+            element_index = flat_index - cumulative_index
+            return list_index, element_index
+        # Update cumulative index
+        cumulative_index += len(sublist)
+
+    raise IndexError("Index out of range for the flattened list.")
+
+def label_specific_elements_in_reference(reference, start_word_idx, start_element_idx, end_word_idx, end_element_idx, label):
+    """
+    Labels elements in a nested list between specified start and end indices (inclusive).
+
+    Args:
+        reference (list of lists): The original list of lists.
+        start_word_idx (int): Index of the starting nested list.
+        start_element_idx (int): Index of the starting element in the start list.
+        end_word_idx (int): Index of the ending nested list.
+        end_element_idx (int): Index of the ending element in the end list.
+        label: The label to attach to the elements.
+
+    Returns:
+        list of lists: A new list of lists with labels attached where applicable.
+    """
+    labeled_reference = []
+    for word_idx, sublist in enumerate(reference):
+        labeled_sublist = []
+
+        for element_idx, element in enumerate(sublist):
+            if start_word_idx < end_word_idx:
+                # Case 1: start_word_idx < end_word_idx
+                if (
+                    (word_idx > start_word_idx and word_idx < end_word_idx) or
+                    (word_idx == start_word_idx and element_idx >= start_element_idx) or
+                    (word_idx == end_word_idx and element_idx <= end_element_idx)
+                ):
+                    # Attach the label to elements within the inclusive range
+                    if isinstance(element, tuple):
+                        print(f"There is already a label at index ({word_idx}, {element_idx})")
+                    labeled_sublist.append((element, label))
+                else:
+                    # Keep elements outside the range unchanged
+                    labeled_sublist.append(element)
+            elif start_word_idx == end_word_idx:
+                # Case 2: start_word_idx == end_word_idx
+                if word_idx == start_word_idx and start_element_idx <= element_idx <= end_element_idx:
+                    # Attach the label to elements within the inclusive range
+                    if isinstance(element, tuple):
+                        print(f"There is already a label at index ({word_idx}, {element_idx})")
+                    labeled_sublist.append((element, label))
+                else:
+                    # Keep elements outside the range unchanged
+                    labeled_sublist.append(element)
+
+        labeled_reference.append(labeled_sublist)
+
+    return labeled_reference
app/routes/__init__.py
ADDED
File without changes

app/routes/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (171 Bytes).

app/routes/__pycache__/predict.cpython-39.pyc
ADDED
Binary file (2.19 kB).

app/routes/__pycache__/transcribe.cpython-39.pyc
ADDED
Binary file (2.22 kB).

app/routes/predict.py
ADDED
@@ -0,0 +1,58 @@
+from fastapi import FastAPI, UploadFile, Form, HTTPException, APIRouter, Depends
+from fastapi.responses import JSONResponse
+import uvicorn
+from typing import List
+import torch
+import soundfile as sf
+from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
+import re
+import numpy as np
+import cmudict
+from io import BytesIO
+import logging
+from joblib import Memory
+from difflib import SequenceMatcher
+import eng_to_ipa as ipa_conv
+import copy
+from IPython.display import HTML, display
+from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+from pydub import AudioSegment
+from Bio import pairwise2
+from Bio.pairwise2 import format_alignment
+import asyncio
+from cachetools import TTLCache
+import time
+import os
+from tempfile import NamedTemporaryFile
+import subprocess
+import librosa
+
+# package imports
+from services.evaluate_pronunciation import PronunciationEvalService
+from utils.general_utils import clean_text
+
+router = APIRouter()
+
+@router.post("/predict", summary="Evaluate pronunciation")
+async def evaluate_pronunciation(audio: UploadFile, transcript: str = Form(...)):
+    """
+    Predict phoneme labels from uploaded audio and provided transcript.
+
+    Args:
+        audio (UploadFile): Uploaded audio file (WAV/MP3).
+        transcript (str): Ground truth transcript.
+
+    Returns:
+        JSONResponse: Contains phoneme labels.
+    """
+    try:
+        # Call the service to process and transcribe the audio
+        service = PronunciationEvalService(transcript, audio)
+        labels = await service.generate_labels()
+
+        response = {'labels': labels}
+        return JSONResponse(content=response)
+
+    except Exception as e:
+        logging.error(f"Error during evaluation: {e}")
+        raise HTTPException(status_code=500, detail="An error occurred during processing.")
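The /predict route now only unpacks the request and delegates to PronunciationEvalService. A hypothetical client call (placeholder URL and file name; requests is not a dependency shown in this diff) might look like:

# Hypothetical client call against the new /predict route.
import requests

with open("sample.m4a", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/predict",
        files={"audio": ("sample.m4a", f, "audio/mp4")},
        data={"transcript": "hello world"},
    )
print(resp.status_code, resp.json())  # expected shape on success: {"labels": [...]}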
app/routes/transcribe.py
ADDED
@@ -0,0 +1,61 @@
+from fastapi import FastAPI, UploadFile, Form, HTTPException, APIRouter, Depends
+from fastapi.responses import JSONResponse
+import uvicorn
+from typing import List
+import torch
+import soundfile as sf
+from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
+import re
+import numpy as np
+import cmudict
+from io import BytesIO
+import logging
+from joblib import Memory
+from difflib import SequenceMatcher
+import eng_to_ipa as ipa_conv
+import copy
+from IPython.display import HTML, display
+from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+from pydub import AudioSegment
+from Bio import pairwise2
+from Bio.pairwise2 import format_alignment
+import asyncio
+from cachetools import TTLCache
+import time
+import os
+from tempfile import NamedTemporaryFile
+import subprocess
+import librosa
+
+# package imports
+from services.transcribe import TranscriptionService
+from utils.general_utils import clean_text
+
+router = APIRouter()
+
+service = TranscriptionService()
+@router.post("/transcribe", summary="Trancribe audio into English")
+async def transcribe(audio: UploadFile):
+    """
+    Transcribe the uploaded audio and return the transcript.
+
+    Args:
+        audio (UploadFile): Uploaded audio file.
+
+    Returns:
+        JSONResponse: Contains the transcript.
+    """
+    try:
+        # Call the service to process and transcribe the audio
+        transcript = await service.transcribe_audio(audio)
+        transcript = clean_text(transcript).strip()
+
+        response = {'transcript': transcript}
+        return JSONResponse(content=response)
+
+    except ValueError as ve:
+        logging.error(f"Validation error: {ve}")
+        raise HTTPException(status_code=400, detail=str(ve))
+    except Exception as e:
+        logging.error(f"Error during transcription: {e}")
+        raise HTTPException(status_code=500, detail="An error occurred during processing.")
app/services/__init__.py
ADDED
File without changes

app/services/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (173 Bytes).

app/services/__pycache__/evaluate_pronunciation.cpython-39.pyc
ADDED
Binary file (2.67 kB).

app/services/__pycache__/transcribe.cpython-39.pyc
ADDED
Binary file (1.9 kB).

app/services/evaluate_pronunciation.py
ADDED
@@ -0,0 +1,69 @@
+import logging
+import time
+import asyncio
+from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+
+from models.ssl_singleton import ssl_model
+from utils.general_utils import process_audio, clean_text
+from modules.pronunciation_coach.pronunciation_assessor import PronunciationAssessor
+from utils.cache import audio_cache
+# process -> call infereence -> structure output -> return
+
+class PronunciationEvalService:
+    def __init__(self, transcript, audio):
+        """
+        Initialize the transcription service.
+
+        Args:
+            transcript (str): Ground truth transcript.
+            audio (UploadFile): Uploaded audio file.
+        """
+        self.ssl_model = ssl_model
+        # device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = 'cpu' # TEMP for testing
+        self.transcript = clean_text(transcript).strip()
+        self.audio = audio
+        self.filename = audio.filename
+        self.uttered_phonemes = None
+        self.assessor = None
+
+    async def get_uttered_phonemes(self):
+        # check if cache has filename
+        audio = self.audio
+        start_time = time.time()
+        audio_inputs = None
+        if await audio_cache.contains(self.filename):
+            async with audio_cache.lock:
+                if audio_cache.cache[self.filename]["uttered_phonemes"] != None:
+                    logging.info(f"Audio '{self.filename}' found in cache.")
+                    end_time = time.time()
+                    logging.info(f"Time from for getting uttered phonemes: {end_time - start_time} seconds")
+                    return audio_cache.cache[self.filename]["uttered_phonemes"]
+                else:
+                    logging.info(f"Audio '{self.filename}' found in cache but not inferenced. Running inference...")
+                    audio_inputs = audio_cache.cache[self.filename]["audio_input"]
+        else:
+            logging.info(f"Audio '{self.filename}' not found in cache. Running inference...")
+
+        if audio_inputs is None:
+            cache_entry = await process_audio(audio, self.device)
+            audio_inputs = cache_entry["audio_input"]
+
+        uttered_phonemes = await self.ssl_model.infer_and_save_to_cache(self.filename, audio_inputs, self.device)
+        end_time = time.time()
+        logging.info(f"Time for getting uttered phonemes: {end_time - start_time} seconds")
+        return uttered_phonemes
+
+    async def generate_labels(self):
+        self.uttered_phonemes = await self.get_uttered_phonemes()
+        start_time = time.time()
+        self.assessor = PronunciationAssessor(self.transcript, self.uttered_phonemes)
+        self.assessor.convert_transcript_into_phonemes()
+        self.assessor.clean_ipa_phonemes()
+        self.assessor.split_phoneme_sequence()
+
+        labels = self.assessor.generate_labels_for_api()
+        end_time = time.time()
+        print("Time taken for label generation after getting uttered phonemes:", end_time - start_time)
+
+        return labels
app/services/transcribe.py
ADDED
@@ -0,0 +1,56 @@
+import logging
+import time
+import asyncio
+from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+
+from models.transcriber_singleton import transcriber_model
+from models.ssl_singleton import ssl_model
+from utils.general_utils import process_audio, clean_text
+
+# from utils.transcribe_utils import transcribe_into_English, clean_text
+# process -> call infereence -> structure output -> return
+
+class TranscriptionService:
+    def __init__(self):
+        """
+        Initialize the transcription service.
+        """
+
+        self.transcriber_model = transcriber_model
+        # device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = 'cpu' # TEMP for testing
+
+    async def transcribe_audio(self, audio):
+        """
+        Process the uploaded audio file and return its transcription.
+
+        Args:
+            audio (UploadFile): Uploaded audio file.
+
+        Returns:
+            str: The transcript.
+        """
+        logging.info("Received transcription request!")
+
+        try:
+            # Step 1: Process the audio and check cache
+            start_time = time.time()
+            cache_entry = await process_audio(audio, self.device)
+            audio_input = cache_entry["audio_input"]
+
+            # Step 2: Start SSL inference in the background
+            asyncio.create_task(ssl_model.infer_and_save_to_cache(audio.filename, audio_input, self.device))
+
+            # Step 3: Get the transcript using Whisper
+            end_time = time.time()
+            logging.info(f"Time from call to finish processing audio: {end_time - start_time} seconds")
+            transcript = self.transcriber_model.transcribe_into_English(audio_input)
+            # Log processing time
+            another_end_time = time.time()
+            logging.info(f"Transcript: {transcript}, Time taken from processed audio to finish transcription: {another_end_time - end_time} seconds")
+
+            return transcript
+
+        except Exception as e:
+            logging.error(f"Error during transcription: {e}")
+            raise
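transcribe_audio schedules SSL phoneme inference with asyncio.create_task and returns the Whisper transcript without waiting for it, so the phonemes can already be cached by the time /predict is called. A toy sketch of that fire-and-forget pattern, using stand-in coroutines rather than the real models:

# Toy illustration of the scheduling used in transcribe_audio (not repo code).
import asyncio

async def phoneme_inference(name: str) -> str:
    await asyncio.sleep(0.1)                      # stand-in for SSL model inference
    return f"phonemes:{name}"

async def handle_transcribe(name: str) -> str:
    task = asyncio.create_task(phoneme_inference(name))  # scheduled, not awaited here
    transcript = f"transcript:{name}"                    # foreground work finishes first
    await task        # only so this demo's loop exits cleanly; the service skips this
    return transcript

print(asyncio.run(handle_transcribe("clip.m4a")))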
notebook-inference.ipynb → app/tester-notebook.ipynb
RENAMED
The diff for this file is too large to render. See raw diff.

app/utils/__init__.py
ADDED
File without changes

app/utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (170 Bytes).

app/utils/__pycache__/cache.cpython-39.pyc
ADDED
Binary file (2.31 kB).

app/utils/__pycache__/general_utils.cpython-39.pyc
ADDED
Binary file (2.42 kB).
app/utils/cache.py
ADDED
@@ -0,0 +1,48 @@
+import asyncio
+from cachetools import TTLCache
+
+class CacheManager:
+    _instance = None
+
+    def __new__(cls, *args, **kwargs):
+        if not cls._instance:
+            cls._instance = super(CacheManager, cls).__new__(cls, *args, **kwargs)
+            cls._instance._initialize()
+        return cls._instance
+
+    def _initialize(self):
+        # Initialize the cache and lock only once
+        self.cache = TTLCache(maxsize=100, ttl=300)
+        self.lock = asyncio.Lock()
+
+    async def set(self, key, value):
+        async with self.lock:
+            self.cache[key] = value
+
+    async def get(self, key):
+        async with self.lock:
+            return self.cache.get(key, None)
+
+    async def contains(self, key):
+        async with self.lock:
+            return key in self.cache
+
+    async def delete(self, key):
+        async with self.lock:
+            if key in self.cache:
+                del self.cache[key]
+
+    def set_without_lock(self, key, value):
+        self.cache[key] = value
+
+    def get_without_lock(self, key):
+        return self.cache.get(key, None)
+
+    def contains_without_lock(self, key):
+        return key in self.cache
+
+    def delete_without_lock(self, key):
+        if key in self.cache:
+            del self.cache[key]
+
+audio_cache = CacheManager()
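CacheManager wraps a cachetools TTLCache (100 entries, 5-minute TTL) behind an asyncio lock, with *_without_lock variants for callers that already hold manager.lock. A minimal usage sketch, assuming the module path added in this commit and app/ on sys.path:

# Minimal async usage of the CacheManager singleton (illustrative).
import asyncio
from utils.cache import audio_cache

async def demo():
    await audio_cache.set("clip.m4a", {"audio_input": None, "uttered_phonemes": None})
    if await audio_cache.contains("clip.m4a"):
        entry = await audio_cache.get("clip.m4a")
        print(entry)
    await audio_cache.delete("clip.m4a")

asyncio.run(demo())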
app/utils/general_utils.py
ADDED
@@ -0,0 +1,73 @@
+import re
+import logging
+import torch
+from tempfile import NamedTemporaryFile
+import numpy as np
+import librosa
+from pydub import AudioSegment
+import subprocess
+import os
+from fastapi import FastAPI, UploadFile, Form, HTTPException
+from io import BytesIO
+from utils.cache import audio_cache
+import asyncio
+
+async def process_audio(audio, device):
+    """
+    Process an uploaded audio file and prepare input for the model.
+
+    Args:
+        audio: The uploaded audio file.
+        device: The device (e.g., 'cuda' or 'cpu') to move tensors to.
+
+    Returns:
+        cache_entry: A dictionary containing processed audio and model input.
+    """
+    filename = audio.filename
+
+    # Check cache for processed audio
+    if await audio_cache.contains(filename):
+        logging.info(f"Audio '{filename}' found in cache.")
+        return await audio_cache.get(filename)
+
+    # Prevent race conditions during cache writes
+    async with audio_cache.lock:
+        # Double-check after acquiring lock
+        if audio_cache.contains_without_lock(filename):
+            logging.info(f"Audio '{filename}' found in cache after lock.")
+            return audio_cache.contains_without_lock(filename)
+        logging.info(f"Processing audio '{filename}'.")
+
+        # Read and preprocess the audio
+        audio_bytes = BytesIO(await audio.read())
+        audio_segment = AudioSegment.from_file(audio_bytes, format="m4a")
+        audio_samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float32)
+        max_val = np.iinfo(np.int16).max
+        audio_samples /= max_val
+
+        if audio_segment.channels > 1:
+            audio_samples = audio_samples.reshape(-1, audio_segment.channels).mean(axis=1)
+
+        audio_input = librosa.resample(audio_samples, orig_sr=audio_segment.frame_rate, target_sr=16000)
+        # input_values = processor(audio_input, return_tensors="pt", sampling_rate=16000).input_values.to(device)
+
+        # Cache the processed audio
+        cache_entry = {"audio_input": audio_input, "input_values": None, "ssl_logits": None}
+        audio_cache.set_without_lock(filename, cache_entry)
+        return cache_entry
+
+def clean_text(text: str) -> str:
+    """
+    Remove punctuation from the input string except for special characters
+    that are part of a word, such as ' in I'm or - in hard-working.
+
+    Parameters:
+        text (str): Input string to clean.
+
+    Returns:
+        str: Cleaned string with allowed special characters retained.
+    """
+    # Allow letters, spaces, apostrophes, and hyphens within words
+    cleaned_text = re.sub(r'[^\w\s\'-]', '', text) # Remove punctuation except ' and -
+    cleaned_text = re.sub(r'\s+', ' ', cleaned_text) # Normalize spaces
+    return cleaned_text.lower().strip()
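clean_text keeps word-internal apostrophes and hyphens while stripping other punctuation and collapsing whitespace; a self-contained illustration of the two regex steps (the expected output is my reading of the regexes, not taken from the repo):

# Illustration of clean_text behaviour (standalone copy of the regex steps).
import re

def clean_text(text: str) -> str:
    cleaned_text = re.sub(r"[^\w\s'-]", '', text)     # drop punctuation except ' and -
    cleaned_text = re.sub(r'\s+', ' ', cleaned_text)  # collapse runs of whitespace
    return cleaned_text.lower().strip()

print(clean_text("I'm hard-working, really!"))  # -> i'm hard-working really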
inference.py
DELETED
@@ -1,214 +0,0 @@
-import torch
-import librosa
-import soundfile as sf
-from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
-import re
-import numpy as np
-import cmudict
-
-# Load the processor and model
-MODEL_NAME = "mrrubino/wav2vec2-large-xlsr-53-l2-arctic-phoneme"  # wav2vec-based phoneme transcriber trained on L2-ARCTIC
-processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
-model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
-model.eval()
-
-# Check device availability
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model.to(device)
-
-def load_audio(audio_path, target_sr=16000):
-    """Load an audio file and resample it to 16kHz."""
-    audio, sr = librosa.load(audio_path, sr=target_sr)
-    return audio
-
-# Original ARPAbet to IPA mapping from SoapBox Labs
-arpabet_to_ipa = {
-    "AA": "a", "AE": "æ", "AH": "ʌ", "AO": "ɔ", "AW": "aʊ", "AY": "aɪ",
-    "EH": "ɛ", "ER": "ɚ", "EY": "eɪ", "IH": "ɪ", "IY": "i", "OW": "oʊ",
-    "OY": "ɔɪ", "UH": "ʊ", "UW": "u", "B": "b", "CH": "t͡ʃ", "D": "d",
-    "DH": "ð", "F": "f", "G": "ɡ", "HH": "h", "JH": "dʒ", "K": "k",
-    "L": "l", "M": "m", "N": "n", "NG": "ŋ", "P": "p", "R": "ɹ",
-    "S": "s", "SH": "ʃ", "T": "t", "TH": "θ", "V": "v", "W": "w",
-    "Y": "j", "Z": "z", "ZH": "ʒ"
-}
-
-# Invert the dictionary to map IPA to ARPAbet
-ipa_to_arpabet = {v: k for k, v in arpabet_to_ipa.items()}
-
-def convert_ipa_to_arpabet(ipa_words):
-    """
-    Convert a list of IPA words (strings of concatenated phonemes) to ARPAbet words.
-
-    :param ipa_words: List of IPA words where each word is a string of concatenated phonemes.
-    :return: List of lists, where each inner list contains ARPAbet phonemes for a word.
-    """
-    arpabet_words = []
-    for word in ipa_words:
-        # Break the word into phonemes
-        phonemes = []  # Collect matched phonemes
-        i = 0
-        while i < len(word):
-            matched = False
-            # Match multi-character IPA phonemes first
-            for ipa_phoneme in sorted(ipa_to_arpabet.keys(), key=len, reverse=True):
-                if word[i:].startswith(ipa_phoneme):
-                    phonemes.append(ipa_to_arpabet[ipa_phoneme])
-                    i += len(ipa_phoneme)
-                    matched = True
-                    break
-            # If no match, add an unknown marker and move forward
-            if not matched:
-                phonemes.append("<UNK>")
-                i += 1
-        # Append the list of phonemes for the word
-        arpabet_words.append(phonemes)
-    return arpabet_words
-
-def remove_numbers_from_phonemes(phon_list):
-    """
-    Remove all numbers from phonemes in a nested list.
-
-    Parameters:
-        phon_list (list of lists): Nested list of phonemes.
-
-    Returns:
-        list of lists: Updated nested list with numbers removed from phonemes.
-    """
-    cleaned_phon_list = []
-    for word_phonemes in phon_list:
-        cleaned_word = [re.sub(r'\d', '', phoneme) for phoneme in word_phonemes]
-        cleaned_phon_list.append(cleaned_word)
-    return cleaned_phon_list
-
-def align_phoneme_sequences(truth_words, uttered_words, gap_penalty=1, substitution_cost=1):
-    """
-    Align phoneme sequences separated by words.
-
-    Parameters:
-        truth_words (list of lists): Ground truth phoneme sequences grouped by words.
-        uttered_words (list of lists): Uttered phoneme sequences grouped by words.
-        gap_penalty (int): Penalty for gaps.
-        substitution_cost (int): Cost for substitutions.
-
-    Returns:
-        alignment (list of tuples): Aligned phoneme sequences with '-' for gaps.
-    """
-    def align_two_sequences(seq1, seq2):
-        """
-        Align two sequences using dynamic programming.
-        """
-        n = len(seq1)
-        m = len(seq2)
-        dp = np.zeros((n + 1, m + 1))
-
-        # Initialize DP table
-        for i in range(n + 1):
-            dp[i][0] = i * gap_penalty
-        for j in range(m + 1):
-            dp[0][j] = j * gap_penalty
-
-        # Fill DP table
-        for i in range(1, n + 1):
-            for j in range(1, m + 1):
-                match_cost = 0 if seq1[i - 1] == seq2[j - 1] else substitution_cost
-                dp[i][j] = min(
-                    dp[i - 1][j - 1] + match_cost,  # Match or substitution
-                    dp[i - 1][j] + gap_penalty,     # Deletion
-                    dp[i][j - 1] + gap_penalty      # Insertion
-                )
-
-        # Traceback to find alignment
-        alignment_seq1 = []
-        alignment_seq2 = []
-        i, j = n, m
-        while i > 0 or j > 0:
-            if i > 0 and j > 0 and dp[i][j] == dp[i - 1][j - 1] + (0 if seq1[i - 1] == seq2[j - 1] else substitution_cost):
-                alignment_seq1.append(seq1[i - 1])
-                alignment_seq2.append(seq2[j - 1])
-                i -= 1
-                j -= 1
-            elif i > 0 and dp[i][j] == dp[i - 1][j] + gap_penalty:
-                alignment_seq1.append(seq1[i - 1])
-                alignment_seq2.append('-')
-                i -= 1
-            else:
-                alignment_seq1.append('-')
-                alignment_seq2.append(seq2[j - 1])
-                j -= 1
-
-        return alignment_seq1[::-1], alignment_seq2[::-1]
-
-    # Align each word pair
-    alignment = []
-    for truth_word, uttered_word in zip(truth_words, uttered_words):
-        aligned_truth, aligned_uttered = align_two_sequences(truth_word, uttered_word)
-        alignment.append((aligned_truth, aligned_uttered))
-
-    return alignment
-
-def generate_phoneme_labels(data):
-    """
-    Generate phoneme labels for comparison of expected and uttered phonemes.
-
-    Parameters:
-        data (list of tuples): Each tuple contains (expected phonemes, uttered phonemes).
-
-    Returns:
-        list of tuples: Each tuple contains (phonemes, labels).
-        Phonemes are from the expected list, and labels are binary (0: correct, 1: incorrect).
-    """
-    results = []
-    for expected, uttered in data:
-        labels = [
-            0 if exp == utt else 1
-            for exp, utt in zip(expected, uttered)
-        ]
-        results.append((expected, labels))
-    return results
-
-def convert_words_to_phonemes(words, cmu_dict):
-    phonemes = []
-    for word in words:
-        if word in cmu_dict:
-            phonemes.extend(cmu_dict[word][0])  # Use the first phoneme representation
-        else:
-            phonemes.append('<UNK>')  # Append '<UNK>' for unknown words
-    return phonemes
-
-# RUN
-
-def predict():
-    cmu = cmudict.dict()
-
-    # Path to test audio file
-    audio_path = '/content/drive/MyDrive/Test Audio/test5-good.m4a'  # Replace with your audio file path
-
-    # Define the script
-    transcript = "the person that sat on the floor is punched"
-
-    # Load audio and normalize
-    audio_input = load_audio(audio_path)
-    input_values = processor(audio_input, return_tensors="pt", sampling_rate=16000).input_values
-    input_values = input_values.to(device)
-
-    # Step 3: Perform inference
-    with torch.no_grad():
-        logits = model(input_values).logits
-
-    # Step 4: Decode the phonemes
-    predicted_ids = torch.argmax(logits, dim=-1)
-    uttured_transcript = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
-
-    # Convert the uttered IPA transcription into ARPAbet (for comparison)
-    uttured_phons = convert_ipa_to_arpabet(uttured_transcript.split())
-
-    # Convert the ground-truth text into ARPAbet (for comparison) and remove (ignore) stress markers (may upgrade to evaluate stress later)
-    trans_phons = [convert_words_to_phonemes([word], cmu) for word in transcript.split()]
-    cleaned_trans_phons = remove_numbers_from_phonemes(trans_phons)
-
-    # Generate labels
-    alignment = align_phoneme_sequences(cleaned_trans_phons, uttured_phons)
-    phoneme_labels = generate_phoneme_labels(alignment)
-
-    print(phoneme_labels)
-    return phoneme_labels
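
A toy, hypothetical walk-through of the alignment and labelling helpers defined in inference.py above (made-up phoneme sequences, assuming those helpers and numpy are in scope):

    # Ground-truth vs. uttered ARPAbet phonemes for two hypothetical words ("the floor"):
    truth = [['DH', 'AH'], ['F', 'L', 'AO', 'R']]
    uttered = [['D', 'AH'], ['F', 'L', 'AO']]   # 'DH' substituted with 'D', final 'R' dropped

    alignment = align_phoneme_sequences(truth, uttered)
    # [(['DH', 'AH'], ['D', 'AH']), (['F', 'L', 'AO', 'R'], ['F', 'L', 'AO', '-'])]

    labels = generate_phoneme_labels(alignment)
    # [(['DH', 'AH'], [1, 0]), (['F', 'L', 'AO', 'R'], [0, 0, 0, 1])]
    # 0 = phoneme matched the expectation, 1 = substitution or omission
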