bachtom125 committed on
Commit
83f43dd
·
1 Parent(s): 7bf7549

refactor: modularize all components

Files changed (38)
  1. Dockerfile +1 -1
  2. app/__pycache__/app.cpython-39.pyc +0 -0
  3. app/__pycache__/main.cpython-39.pyc +0 -0
  4. app/main.py +49 -0
  5. app/models/__init__.py +0 -0
  6. app/models/__pycache__/__init__.cpython-39.pyc +0 -0
  7. app/models/__pycache__/ssl_singleton.cpython-39.pyc +0 -0
  8. app/models/__pycache__/transcriber_singleton.cpython-39.pyc +0 -0
  9. app/models/ssl_singleton.py +48 -0
  10. app/models/transcriber_singleton.py +45 -0
  11. app/modules/__init__.py +0 -0
  12. app/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  13. app/modules/pronunciation_coach/__init__.py +0 -0
  14. app/modules/pronunciation_coach/__pycache__/__init__.cpython-39.pyc +0 -0
  15. app/modules/pronunciation_coach/__pycache__/pronunciation_assessor.cpython-39.pyc +0 -0
  16. app/modules/pronunciation_coach/__pycache__/pronunciation_assessor_utils.cpython-39.pyc +0 -0
  17. app.py → app/modules/pronunciation_coach/pronunciation_assessor.py +3 -395
  18. app/modules/pronunciation_coach/pronunciation_assessor_utils.py +73 -0
  19. app/routes/__init__.py +0 -0
  20. app/routes/__pycache__/__init__.cpython-39.pyc +0 -0
  21. app/routes/__pycache__/predict.cpython-39.pyc +0 -0
  22. app/routes/__pycache__/transcribe.cpython-39.pyc +0 -0
  23. app/routes/predict.py +58 -0
  24. app/routes/transcribe.py +61 -0
  25. app/services/__init__.py +0 -0
  26. app/services/__pycache__/__init__.cpython-39.pyc +0 -0
  27. app/services/__pycache__/evaluate_pronunciation.cpython-39.pyc +0 -0
  28. app/services/__pycache__/transcribe.cpython-39.pyc +0 -0
  29. app/services/evaluate_pronunciation.py +69 -0
  30. app/services/transcribe.py +56 -0
  31. notebook-inference.ipynb → app/tester-notebook.ipynb +0 -0
  32. app/utils/__init__.py +0 -0
  33. app/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  34. app/utils/__pycache__/cache.cpython-39.pyc +0 -0
  35. app/utils/__pycache__/general_utils.cpython-39.pyc +0 -0
  36. app/utils/cache.py +48 -0
  37. app/utils/general_utils.py +73 -0
  38. inference.py +0 -214
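
For orientation, the refactor in the file list above resolves to the package layout below (a sketch; the committed __pycache__/*.pyc artifacts are omitted):

app/
├── main.py                        # FastAPI entry point; mounts the routers
├── routes/                        # HTTP endpoints: predict.py, transcribe.py
├── services/                      # request orchestration: evaluate_pronunciation.py, transcribe.py
├── models/                        # model singletons: ssl_singleton.py, transcriber_singleton.py
├── modules/pronunciation_coach/   # PronunciationAssessor and its helpers
├── utils/                         # cache.py (shared TTL cache), general_utils.py
└── tester-notebook.ipynb          # moved from notebook-inference.ipynb
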
Dockerfile CHANGED
@@ -29,4 +29,4 @@ COPY . .
  EXPOSE 7860

  # Run the FastAPI application
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
app/__pycache__/app.cpython-39.pyc ADDED
Binary file (1.71 kB).
 
app/__pycache__/main.cpython-39.pyc ADDED
Binary file (1.8 kB).
 
app/main.py ADDED
@@ -0,0 +1,49 @@
1
+
2
+ from fastapi import FastAPI, UploadFile, Form, HTTPException
3
+ from fastapi.responses import JSONResponse
4
+ import uvicorn
5
+ from typing import List
6
+ import torch
7
+ import soundfile as sf
8
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
9
+ import re
10
+ import numpy as np
11
+ import cmudict
12
+ from io import BytesIO
13
+ import os
14
+ import logging
15
+ from joblib import Memory
16
+ from difflib import SequenceMatcher
17
+ import eng_to_ipa as ipa_conv
18
+ import os
19
+ import copy
20
+ from IPython.display import HTML, display
21
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
22
+ from pydub import AudioSegment
23
+ from Bio import pairwise2
24
+ from Bio.pairwise2 import format_alignment
25
+ import asyncio
26
+ from cachetools import TTLCache
27
+
28
+ # Set the Numba cache directory to a writable location
29
+ os.environ["NUMBA_CACHE_DIR"] = "/tmp"
30
+ import librosa
31
+ logging.basicConfig(level=logging.INFO)
32
+
33
+ # package imports
34
+ from routes.transcribe import router as transcriber_router
35
+ from routes.predict import router as pronunciation_evaluation_router
36
+ # Initialize FastAPI app
37
+ app = FastAPI(title="Talkiee AI", version="1.0.0")
38
+
39
+ # health check
40
+ @app.get("/")
41
+ def home():
42
+ return "Healthy bro!"
43
+
44
+ app.include_router(transcriber_router, tags=["transcribe"])
45
+ app.include_router(pronunciation_evaluation_router, tags=["pronunciation_evaluation"])
46
+ if __name__ == '__main__':
47
+ port = os.environ.get("PORT", 10000) # Default to 10000 if PORT is not set
48
+ logging.info(f"Starting server on PORT {port}")
49
+ uvicorn.run("main:app", host="0.0.0.0", port=int(port), log_level="info")
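
A minimal smoke test for the new entry point (a sketch, assuming it runs from inside app/ so that main and the route packages are importable, as the module's own imports require):

from fastapi.testclient import TestClient
from main import app  # app/main.py

def test_health_check():
    client = TestClient(app)
    response = client.get("/")          # health-check route defined above
    assert response.status_code == 200
    assert response.json() == "Healthy bro!"
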
app/models/__init__.py ADDED
File without changes
app/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (171 Bytes).
 
app/models/__pycache__/ssl_singleton.cpython-39.pyc ADDED
Binary file (2.04 kB).
 
app/models/__pycache__/transcriber_singleton.cpython-39.pyc ADDED
Binary file (2.04 kB).
 
app/models/ssl_singleton.py ADDED
@@ -0,0 +1,48 @@
1
+ import os
2
+ import torch
3
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
4
+ from utils.general_utils import process_audio
5
+ import asyncio
6
+ import librosa
7
+ from utils.cache import audio_cache
8
+
9
+ class SSLSingleton:
10
+ _instance = None
11
+
12
+ def __new__(cls, model_name="mrrubino/wav2vec2-large-xlsr-53-l2-arctic-phoneme", device=None):
13
+ if cls._instance is None:
14
+ cls._instance = super(SSLSingleton, cls).__new__(cls)
15
+ cls._instance._initialize(model_name, device)
16
+ return cls._instance
17
+
18
+ def _initialize(self, model_name, device):
19
+ # Set device (CPU or GPU)
20
+ # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ self.device = "cpu"
22
+ # Load processor and model
23
+ print("Loading SSL processor and model...") # This will only happen once
24
+ self.processor = Wav2Vec2Processor.from_pretrained(model_name)
25
+ self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
26
+ self.model.eval()
27
+ self.model.to(self.device) # Move model to the specified device
28
+
29
+ # an inference function taking processed audio input and returning the predicted phoneme string
30
+ def infer(self, audio_input, device):
31
+ inputs = self.processor(audio_input, sampling_rate=16000, return_tensors="pt")
32
+ inputs = inputs.to(self.device)
33
+
34
+ with torch.no_grad():
35
+ logits = self.model(inputs.input_values).logits
36
+
37
+ predicted_ids = torch.argmax(logits, dim=-1)
38
+ uttered_phonemes = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
39
+ return uttered_phonemes
40
+
41
+ async def infer_and_save_to_cache(self, file_name, audio_input, device):
42
+ uttered_phonemes = self.infer(audio_input, device)
43
+ async with audio_cache.lock:
44
+ new_cache = audio_cache.cache[file_name]
45
+ new_cache["uttered_phonemes"] = uttered_phonemes
46
+ audio_cache.cache[file_name] = new_cache
47
+ return uttered_phonemes
48
+ ssl_model = SSLSingleton()
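
The singleton contract can be sanity-checked directly (a sketch, again assuming app/ is on the import path):

from models.ssl_singleton import SSLSingleton, ssl_model

assert SSLSingleton() is ssl_model  # __new__ returns the already-initialized instance, so the weights load only once
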
app/models/transcriber_singleton.py ADDED
@@ -0,0 +1,45 @@
1
+ import os
2
+ import torch
3
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
4
+ from utils.general_utils import process_audio
5
+ import asyncio
6
+ import librosa
7
+
8
+ class TranscriberSingleton:
9
+ _instance = None
10
+
11
+ def __new__(cls, model_name="openai/whisper-tiny.en", device=None):
12
+ if cls._instance is None:
13
+ cls._instance = super(TranscriberSingleton, cls).__new__(cls)
14
+ cls._instance._initialize(model_name, device)
15
+ return cls._instance
16
+
17
+ def _initialize(self, model_name, device):
18
+ # Set device (CPU or GPU)
19
+ # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ self.device = "cpu"
21
+ # Load processor and model
22
+ print(f"Loading Whisper processor and model into {self.device}...") # This will only happen once
23
+ self.processor = AutoProcessor.from_pretrained(model_name)
24
+ self.model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name)
25
+ self.model.eval()
26
+ self.model.to(self.device) # Move model to the specified device
27
+
28
+ def transcribe_into_English(self, audio_input):
29
+ # Load audio file
30
+ audio_input = self.processor(audio_input, sampling_rate=16000, return_tensors="pt", language="en").to(self.device)
31
+
32
+ # Perform transcription
33
+ with torch.no_grad():
34
+ generated_ids = self.model.generate(audio_input.input_features)
35
+
36
+ # Decode the transcription
37
+ transcription = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
38
+ return transcription.lower().strip()
39
+
40
+ def transcribe_from_file_path(self, file_path, target_sr=16000):
41
+ with open(file_path, "rb") as f:
42
+ audio_input, sr = librosa.load(file_path, sr=target_sr)
43
+ return self.transcribe_into_English(audio_input)
44
+
45
+ transcriber_model = TranscriberSingleton()
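
Usage sketch for the Whisper wrapper; "sample.wav" is a hypothetical local recording, resampled to 16 kHz by transcribe_from_file_path itself:

from models.transcriber_singleton import transcriber_model

text = transcriber_model.transcribe_from_file_path("sample.wav")
print(text)  # lower-cased, stripped English transcript
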
app/modules/__init__.py ADDED
File without changes
app/modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (172 Bytes).
 
app/modules/pronunciation_coach/__init__.py ADDED
File without changes
app/modules/pronunciation_coach/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (192 Bytes).
 
app/modules/pronunciation_coach/__pycache__/pronunciation_assessor.cpython-39.pyc ADDED
Binary file (22.6 kB).
 
app/modules/pronunciation_coach/__pycache__/pronunciation_assessor_utils.cpython-39.pyc ADDED
Binary file (2.16 kB).
 
app.py → app/modules/pronunciation_coach/pronunciation_assessor.py RENAMED
@@ -1,7 +1,3 @@
1
-
2
- from fastapi import FastAPI, UploadFile, Form, HTTPException
3
- from fastapi.responses import JSONResponse
4
- import uvicorn
5
  from typing import List
6
  import torch
7
  import soundfile as sf
@@ -10,7 +6,6 @@ import re
10
  import numpy as np
11
  import cmudict
12
  from io import BytesIO
13
- import os
14
  import logging
15
  from joblib import Memory
16
  from difflib import SequenceMatcher
@@ -24,267 +19,10 @@ from Bio import pairwise2
24
  from Bio.pairwise2 import format_alignment
25
  import asyncio
26
  from cachetools import TTLCache
27
-
28
- # Set the Numba cache directory to a writable location
29
- os.environ["NUMBA_CACHE_DIR"] = "/tmp"
30
- import librosa
31
-
32
- logging.basicConfig(level=logging.INFO)
33
-
34
- cmu = cmudict.dict()
35
-
36
- # Initialize FastAPI app
37
- app = FastAPI()
38
-
39
- # Load the processor and model
40
- MODEL_NAME = "mrrubino/wav2vec2-large-xlsr-53-l2-arctic-phoneme" # wav2vec based phoneme trascriber trained on L2-ARTIC
41
- processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
42
- model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
43
- model.eval()
44
-
45
- # Check device availability
46
- # device = "cuda" if torch.cuda.is_available() else "cpu"
47
- device = 'cpu' # TEMP for testing
48
- model.to(device)
49
-
50
- whisper_processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
51
- whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny.en")
52
- whisper_model.eval()
53
- whisper_model.to(device)
54
-
55
- # =====================================
56
- # Section: Utils
57
- # =====================================
58
-
59
- # Initialize a cache with a 5-minute TTL and 100 items max
60
- audio_cache = TTLCache(maxsize=100, ttl=300)
61
- cache_lock = asyncio.Lock() # To prevent race conditions
62
-
63
- import os
64
- from tempfile import NamedTemporaryFile
65
- import subprocess
66
-
67
- async def process_audio(audio, device):
68
- """
69
- Process an uploaded audio file and prepare input for the model.
70
- Converts audio to WAV format using FFmpeg prior to processing.
71
-
72
- Args:
73
- audio: The uploaded audio file.
74
- device: The device (e.g., 'cuda' or 'cpu') to move tensors to.
75
-
76
- Returns:
77
- cache_entry: A dictionary containing processed audio and model input.
78
- """
79
- filename = audio.filename
80
-
81
- # Check cache for processed audio
82
- if filename in audio_cache:
83
- logging.info(f"Audio '{filename}' found in cache.")
84
- return audio_cache[filename]
85
-
86
- async with cache_lock: # Prevent race conditions during cache writes
87
- if filename in audio_cache: # Double-check after acquiring lock
88
- logging.info(f"Audio '{filename}' found in cache after lock.")
89
- return audio_cache[filename]
90
-
91
- logging.info(f"Processing audio '{filename}'.")
92
-
93
- # Read the audio file into a temporary file
94
- with NamedTemporaryFile(delete=False, suffix=".m4a") as temp_m4a:
95
- temp_m4a_path = temp_m4a.name
96
- temp_m4a.write(await audio.read())
97
-
98
- # Convert M4A to WAV using FFmpeg
99
- temp_wav_path = temp_m4a_path.replace(".m4a", ".wav")
100
- try:
101
- subprocess.run(
102
- [
103
- "ffmpeg", "-i", temp_m4a_path, # Input file
104
- "-ar", "16000", # Resample to 16kHz
105
- "-ac", "1", # Convert to mono
106
- temp_wav_path # Output file
107
- ],
108
- check=True,
109
- stdout=subprocess.PIPE,
110
- stderr=subprocess.PIPE
111
- )
112
- except subprocess.CalledProcessError as e:
113
- logging.error(f"FFmpeg conversion failed: {e.stderr.decode()}")
114
- raise HTTPException(status_code=500, detail="Failed to process audio file.")
115
- finally:
116
- os.remove(temp_m4a_path) # Clean up the temporary M4A file
117
-
118
- try:
119
- # Load the WAV audio for further processing
120
- audio_segment = AudioSegment.from_file(temp_wav_path, format="wav")
121
- audio_samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float32)
122
- max_val = np.iinfo(np.int16).max
123
- audio_samples /= max_val
124
-
125
- if audio_segment.channels > 1:
126
- audio_samples = audio_samples.reshape(-1, audio_segment.channels).mean(axis=1)
127
-
128
- audio_input = librosa.resample(audio_samples, orig_sr=audio_segment.frame_rate, target_sr=16000)
129
- input_values = processor(audio_input, return_tensors="pt", sampling_rate=16000).input_values.to(device)
130
-
131
- # Cache the processed audio
132
- cache_entry = {"audio_input": audio_input, "input_values": input_values, "ssl_logits": None}
133
- audio_cache[filename] = cache_entry
134
- return cache_entry
135
-
136
- finally:
137
- # Clean up the temporary WAV file
138
- os.remove(temp_wav_path)
139
-
140
-
141
- async def run_ssl_inference(filename, input_values):
142
- """
143
- Run SSL model inference in the background and store the results in the cache.
144
-
145
- Args:
146
- filename: The name of the audio file.
147
- input_values: The processed input tensor for the SSL model.
148
- """
149
- try:
150
- logging.info(f"Running SSL inference for '{filename}' in the background.")
151
- with torch.no_grad():
152
- ssl_output = model(input_values).logits
153
-
154
- # Update the cache with the SSL inference result
155
- if filename in audio_cache:
156
- audio_cache[filename]["ssl_logits"] = ssl_output
157
- logging.info(f"SSL inference for '{filename}' completed and cached.")
158
- except Exception as e:
159
- logging.error(f"Error during SSL inference for '{filename}': {e}")
160
-
161
- def transcribe_into_English(audio_input):
162
- # Load audio file
163
- audio_input = whisper_processor(audio_input, sampling_rate=16000, return_tensors="pt", language="en").to(device)
164
-
165
- # Perform transcription
166
- with torch.no_grad():
167
- generated_ids = whisper_model.generate(audio_input.input_features)
168
-
169
- # Decode the transcription
170
- transcription = whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
171
- return transcription.lower().strip()
172
-
173
- def get_nested_position(nested_list, flat_index):
174
- """
175
- Finds the nested list and the index within it for a given flat index.
176
-
177
- Args:
178
- nested_list (list of lists): The list of lists.
179
- flat_index (int): The flattened index.
180
-
181
- Returns:
182
- tuple: (nested_list_index, element_index_in_nested_list)
183
- """
184
- cumulative_index = 0
185
-
186
- for list_index, sublist in enumerate(nested_list):
187
- # Check if the flat index falls within the current sublist
188
- if cumulative_index + len(sublist) > flat_index:
189
- # Calculate the index within the sublist
190
- element_index = flat_index - cumulative_index
191
- return list_index, element_index
192
- # Update cumulative index
193
- cumulative_index += len(sublist)
194
-
195
- raise IndexError("Index out of range for the flattened list.")
196
-
197
- def label_specific_elements_in_reference(reference, start_word_idx, start_element_idx, end_word_idx, end_element_idx, label):
198
- """
199
- Labels elements in a nested list between specified start and end indices (inclusive).
200
-
201
- Args:
202
- reference (list of lists): The original list of lists.
203
- start_word_idx (int): Index of the starting nested list.
204
- start_element_idx (int): Index of the starting element in the start list.
205
- end_word_idx (int): Index of the ending nested list.
206
- end_element_idx (int): Index of the ending element in the end list.
207
- label: The label to attach to the elements.
208
-
209
- Returns:
210
- list of lists: A new list of lists with labels attached where applicable.
211
- """
212
- labeled_reference = []
213
- for word_idx, sublist in enumerate(reference):
214
- labeled_sublist = []
215
-
216
- for element_idx, element in enumerate(sublist):
217
- if start_word_idx < end_word_idx:
218
- # Case 1: start_word_idx < end_word_idx
219
- if (
220
- (word_idx > start_word_idx and word_idx < end_word_idx) or
221
- (word_idx == start_word_idx and element_idx >= start_element_idx) or
222
- (word_idx == end_word_idx and element_idx <= end_element_idx)
223
- ):
224
- # Attach the label to elements within the inclusive range
225
- if isinstance(element, tuple):
226
- print(f"There is already a label at index ({word_idx}, {element_idx})")
227
- labeled_sublist.append((element, label))
228
- else:
229
- # Keep elements outside the range unchanged
230
- labeled_sublist.append(element)
231
- elif start_word_idx == end_word_idx:
232
- # Case 2: start_word_idx == end_word_idx
233
- if word_idx == start_word_idx and start_element_idx <= element_idx <= end_element_idx:
234
- # Attach the label to elements within the inclusive range
235
- if isinstance(element, tuple):
236
- print(f"There is already a label at index ({word_idx}, {element_idx})")
237
- labeled_sublist.append((element, label))
238
- else:
239
- # Keep elements outside the range unchanged
240
- labeled_sublist.append(element)
241
-
242
- labeled_reference.append(labeled_sublist)
243
-
244
- return labeled_reference
245
-
246
- def clean_text(text: str) -> str:
247
- """
248
- Remove punctuation from the input string except for special characters
249
- that are part of a word, such as ' in I'm or - in hard-working.
250
-
251
- Parameters:
252
- text (str): Input string to clean.
253
-
254
- Returns:
255
- str: Cleaned string with allowed special characters retained.
256
- """
257
- # Allow letters, spaces, apostrophes, and hyphens within words
258
- cleaned_text = re.sub(r'[^\w\s\'-]', '', text) # Remove punctuation except ' and -
259
- cleaned_text = re.sub(r'\s+', ' ', cleaned_text) # Normalize spaces
260
- return cleaned_text.lower().strip()
261
-
262
- # =====================================
263
- # Section: IPA Phonemes Utils
264
- # =====================================
265
-
266
-
267
- # WORKING: converting functions to class, currently done with the last function in the class
268
- import re
269
- from difflib import SequenceMatcher
270
- from IPython.display import HTML, display
271
- import copy
272
- from IPython.display import HTML, display
273
- from Bio import pairwise2
274
- from Bio.pairwise2 import format_alignment
275
-
276
- # WORKING: converting functions to class, currently done with the last function in the class
277
- import re
278
- from difflib import SequenceMatcher
279
- from IPython.display import HTML, display
280
- import copy
281
- from IPython.display import HTML, display
282
- from Bio import pairwise2
283
- from Bio.pairwise2 import format_alignment
284
- import cmudict
285
  cmu_dict = cmudict.dict()
286
 
287
- class PronunciationAssessment:
288
  def __init__(self, transcript, uttered_phonemes):
289
  # NOTE: removed all long signals ('ː') for compatibility with L2-artic's phoneme set (ssl model training set). American English.
290
  # ground truth phonemes are converted into arpabet first, and then into ipa using the arpabet_to_ipa dict, meaning the arpabet_to_ipa dict contains
@@ -1159,134 +897,4 @@ class PronunciationAssessment:
1159
 
1160
  # Display
1161
  display(HTML(html_content))
1162
-
1163
- # health check
1164
- @app.get("/")
1165
- def home():
1166
- return "Healthy bro!"
1167
-
1168
- import time # temp
1169
-
1170
- # taking in both audio and transcript from the user
1171
- @app.post("/predict")
1172
- async def predict(audio: UploadFile, transcript: str = Form(...)):
1173
- """
1174
- Predict phoneme labels from uploaded audio and provided transcript.
1175
-
1176
- Args:
1177
- audio (UploadFile): Uploaded audio file (WAV/MP3).
1178
- transcript (str): Ground truth transcript.
1179
-
1180
- Returns:
1181
- JSONResponse: Contains phoneme labels.
1182
- """
1183
- logging.info("Received prediction request!")
1184
-
1185
- # Validate file extension
1186
- allowed_extensions = {"wav", "mp3", "m4a"}
1187
- filename = audio.filename.lower()
1188
- start_time = time.time()
1189
-
1190
- if not filename.endswith(tuple(allowed_extensions)):
1191
- raise HTTPException(
1192
- status_code=400,
1193
- detail="Invalid file type. Only WAV and MP3 files are supported.",
1194
- )
1195
-
1196
- # Load and preprocess the audio
1197
- try:
1198
- cache_entry = await process_audio(audio, device)
1199
- input_values = cache_entry["input_values"]
1200
-
1201
- # Ensure SSL inference is completed
1202
- logits = cache_entry.get("ssl_logits")
1203
- if logits is None:
1204
- logging.info(f"SSL inference not cached for '{filename}', running now.")
1205
- with torch.no_grad():
1206
- logits = model(input_values).logits
1207
- cache_entry["ssl_logits"] = logits
1208
-
1209
- end_time = time.time()
1210
- print(f"Time from call to finish processing audio: {end_time - start_time} seconds")
1211
-
1212
- start_time = time.time()
1213
- transcript = clean_text(transcript).strip()
1214
-
1215
- # Decode the phonemes
1216
- predicted_ids = torch.argmax(logits, dim=-1)
1217
- uttered_phonemes = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
1218
- end_time = time.time()
1219
- print("Time taken for inference:", end_time - start_time)
1220
-
1221
- start_time = time.time()
1222
- # init PronunciationAssessment instance
1223
- cur = PronunciationAssessment(transcript, uttered_phonemes)
1224
- cur.convert_transcript_into_phonemes()
1225
- cur.clean_ipa_phonemes()
1226
- cur.split_phoneme_sequence()
1227
- print(cur.uttered_ipa_phonemes)
1228
- # print(cur.segmented_ground_truth_ipa_phonemes)
1229
- # print(cur.segmented_uttered_ipa_phonemes)
1230
-
1231
- # generate the final labels
1232
- labels = cur.generate_labels_for_api()
1233
- end_time = time.time()
1234
- print("Time taken for label generation:", end_time - start_time)
1235
- return JSONResponse(content={"labels": labels})
1236
-
1237
- except Exception as e:
1238
- logging.error(f"Error during prediction: {e}")
1239
- raise HTTPException(status_code=500, detail="An error occurred during processing.")
1240
-
1241
- # taking in audio only and returning the transcript
1242
- @app.post("/transcribe")
1243
- async def transcribe(audio: UploadFile):
1244
- """
1245
- Transcribe the uploaded audio and return the transcript.
1246
-
1247
- Args:
1248
- audio (UploadFile): Uploaded audio file (WAV/MP3).
1249
-
1250
- Returns:
1251
- JSONResponse: Contains the transcript.
1252
- """
1253
- logging.info("Received transcription request!")
1254
-
1255
- # Validate file extension
1256
- allowed_extensions = {"wav", "mp3", "m4a"}
1257
- filename = audio.filename.lower()
1258
- if not filename.endswith(tuple(allowed_extensions)):
1259
- raise HTTPException(
1260
- status_code=400,
1261
- detail="Invalid file type. Only WAV and MP3 files are supported.",
1262
- )
1263
-
1264
- # Load and preprocess the audio
1265
- try:
1266
- # Process the audio
1267
- start_time = time.time()
1268
- cache_entry = await process_audio(audio, device)
1269
- audio_input = cache_entry["audio_input"]
1270
- input_values = cache_entry["input_values"]
1271
-
1272
- # Start SSL inference in the background
1273
- asyncio.create_task(run_ssl_inference(audio.filename, input_values))
1274
-
1275
- # Get transcript
1276
- end_time = time.time()
1277
- print(f"Time from call to finish processing audio: {end_time - start_time} seconds")
1278
- transcript = transcribe_into_English(audio_input)
1279
- transcript = clean_text(transcript).strip()
1280
- another_end_time = time.time()
1281
- logging.info(f"Transcript: {transcript}, Time taken from processed audio to finish transcription: {another_end_time - end_time} seconds")
1282
-
1283
- return JSONResponse(content={"transcript": transcript})
1284
-
1285
- except Exception as e:
1286
- logging.error(f"Error during transcription: {e}")
1287
- raise HTTPException(status_code=500, detail="An error occurred during processing.")
1288
-
1289
- # if __name__ == '__main__':
1290
- # port = os.environ.get("PORT", 10000) # Default to 10000 if PORT is not set
1291
- # logging.info(f"Starting server on PORT {port}")
1292
- # uvicorn.run("app:app", host="0.0.0.0", port=int(port), log_level="info")
 
1
  from typing import List
2
  import torch
3
  import soundfile as sf
 
6
  import numpy as np
7
  import cmudict
8
  from io import BytesIO
 
9
  import logging
10
  from joblib import Memory
11
  from difflib import SequenceMatcher
 
19
  from Bio.pairwise2 import format_alignment
20
  import asyncio
21
  from cachetools import TTLCache
22
+ from modules.pronunciation_coach.pronunciation_assessor_utils import *
 
23
  cmu_dict = cmudict.dict()
24
 
25
+ class PronunciationAssessor:
26
  def __init__(self, transcript, uttered_phonemes):
27
  # NOTE: removed all long signals ('ː') for compatibility with L2-artic's phoneme set (ssl model training set). American English.
28
  # ground truth phonemes are converted into arpabet first, and then into ipa using the arpabet_to_ipa dict, meaning the arpabet_to_ipa dict contains
 
897
 
898
  # Display
899
  display(HTML(html_content))
900
+
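
Offline usage sketch for the renamed class, mirroring the call sequence the new evaluation service uses (the transcript and phoneme string are illustrative; real input comes from the SSL model):

from modules.pronunciation_coach.pronunciation_assessor import PronunciationAssessor

assessor = PronunciationAssessor("the cat", "ðə kæt")
assessor.convert_transcript_into_phonemes()
assessor.clean_ipa_phonemes()
assessor.split_phoneme_sequence()
labels = assessor.generate_labels_for_api()  # per-phoneme labels for the API response
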
 
app/modules/pronunciation_coach/pronunciation_assessor_utils.py ADDED
@@ -0,0 +1,73 @@
1
+ import numpy as np
2
+ def get_nested_position(nested_list, flat_index):
3
+ """
4
+ Finds the nested list and the index within it for a given flat index.
5
+
6
+ Args:
7
+ nested_list (list of lists): The list of lists.
8
+ flat_index (int): The flattened index.
9
+
10
+ Returns:
11
+ tuple: (nested_list_index, element_index_in_nested_list)
12
+ """
13
+ cumulative_index = 0
14
+
15
+ for list_index, sublist in enumerate(nested_list):
16
+ # Check if the flat index falls within the current sublist
17
+ if cumulative_index + len(sublist) > flat_index:
18
+ # Calculate the index within the sublist
19
+ element_index = flat_index - cumulative_index
20
+ return list_index, element_index
21
+ # Update cumulative index
22
+ cumulative_index += len(sublist)
23
+
24
+ raise IndexError("Index out of range for the flattened list.")
25
+
26
+ def label_specific_elements_in_reference(reference, start_word_idx, start_element_idx, end_word_idx, end_element_idx, label):
27
+ """
28
+ Labels elements in a nested list between specified start and end indices (inclusive).
29
+
30
+ Args:
31
+ reference (list of lists): The original list of lists.
32
+ start_word_idx (int): Index of the starting nested list.
33
+ start_element_idx (int): Index of the starting element in the start list.
34
+ end_word_idx (int): Index of the ending nested list.
35
+ end_element_idx (int): Index of the ending element in the end list.
36
+ label: The label to attach to the elements.
37
+
38
+ Returns:
39
+ list of lists: A new list of lists with labels attached where applicable.
40
+ """
41
+ labeled_reference = []
42
+ for word_idx, sublist in enumerate(reference):
43
+ labeled_sublist = []
44
+
45
+ for element_idx, element in enumerate(sublist):
46
+ if start_word_idx < end_word_idx:
47
+ # Case 1: start_word_idx < end_word_idx
48
+ if (
49
+ (word_idx > start_word_idx and word_idx < end_word_idx) or
50
+ (word_idx == start_word_idx and element_idx >= start_element_idx) or
51
+ (word_idx == end_word_idx and element_idx <= end_element_idx)
52
+ ):
53
+ # Attach the label to elements within the inclusive range
54
+ if isinstance(element, tuple):
55
+ print(f"There is already a label at index ({word_idx}, {element_idx})")
56
+ labeled_sublist.append((element, label))
57
+ else:
58
+ # Keep elements outside the range unchanged
59
+ labeled_sublist.append(element)
60
+ elif start_word_idx == end_word_idx:
61
+ # Case 2: start_word_idx == end_word_idx
62
+ if word_idx == start_word_idx and start_element_idx <= element_idx <= end_element_idx:
63
+ # Attach the label to elements within the inclusive range
64
+ if isinstance(element, tuple):
65
+ print(f"There is already a label at index ({word_idx}, {element_idx})")
66
+ labeled_sublist.append((element, label))
67
+ else:
68
+ # Keep elements outside the range unchanged
69
+ labeled_sublist.append(element)
70
+
71
+ labeled_reference.append(labeled_sublist)
72
+
73
+ return labeled_reference
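
Worked example for the two helpers above (phoneme values are illustrative):

reference = [["ð", "ə"], ["k", "æ", "t"]]

# Flat index 3 falls in the second word at position 1 ("æ").
assert get_nested_position(reference, 3) == (1, 1)

# Label everything from (word 0, element 1) through (word 1, element 0), inclusive.
labeled = label_specific_elements_in_reference(reference, 0, 1, 1, 0, "mispronounced")
# -> [['ð', ('ə', 'mispronounced')], [('k', 'mispronounced'), 'æ', 't']]
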
app/routes/__init__.py ADDED
File without changes
app/routes/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (171 Bytes).
 
app/routes/__pycache__/predict.cpython-39.pyc ADDED
Binary file (2.19 kB).
 
app/routes/__pycache__/transcribe.cpython-39.pyc ADDED
Binary file (2.22 kB).
 
app/routes/predict.py ADDED
@@ -0,0 +1,58 @@
1
+ from fastapi import FastAPI, UploadFile, Form, HTTPException, APIRouter, Depends
2
+ from fastapi.responses import JSONResponse
3
+ import uvicorn
4
+ from typing import List
5
+ import torch
6
+ import soundfile as sf
7
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
8
+ import re
9
+ import numpy as np
10
+ import cmudict
11
+ from io import BytesIO
12
+ import logging
13
+ from joblib import Memory
14
+ from difflib import SequenceMatcher
15
+ import eng_to_ipa as ipa_conv
16
+ import copy
17
+ from IPython.display import HTML, display
18
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
19
+ from pydub import AudioSegment
20
+ from Bio import pairwise2
21
+ from Bio.pairwise2 import format_alignment
22
+ import asyncio
23
+ from cachetools import TTLCache
24
+ import time
25
+ import os
26
+ from tempfile import NamedTemporaryFile
27
+ import subprocess
28
+ import librosa
29
+
30
+ # package imports
31
+ from services.evaluate_pronunciation import PronunciationEvalService
32
+ from utils.general_utils import clean_text
33
+
34
+ router = APIRouter()
35
+
36
+ @router.post("/predict", summary="Evaluate pronunciation")
37
+ async def evaluate_pronunciation(audio: UploadFile, transcript: str = Form(...)):
38
+ """
39
+ Predict phoneme labels from uploaded audio and provided transcript.
40
+
41
+ Args:
42
+ audio (UploadFile): Uploaded audio file (WAV/MP3).
43
+ transcript (str): Ground truth transcript.
44
+
45
+ Returns:
46
+ JSONResponse: Contains phoneme labels.
47
+ """
48
+ try:
49
+ # Call the service to process and transcribe the audio
50
+ service = PronunciationEvalService(transcript, audio)
51
+ labels = await service.generate_labels()
52
+
53
+ response = {'labels': labels}
54
+ return JSONResponse(content=response)
55
+
56
+ except Exception as e:
57
+ logging.error(f"Error during evaluation: {e}")
58
+ raise HTTPException(status_code=500, detail="An error occurred during processing.")
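
Client-side sketch for the new /predict route, assuming the container from the Dockerfile is listening on localhost:7860 and "sample.m4a" is a hypothetical recording of the given transcript:

import requests

with open("sample.m4a", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/predict",
        files={"audio": ("sample.m4a", f)},
        data={"transcript": "the person that sat on the floor is punched"},
    )
print(resp.json()["labels"])  # phoneme labels produced by PronunciationEvalService
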
app/routes/transcribe.py ADDED
@@ -0,0 +1,61 @@
1
+ from fastapi import FastAPI, UploadFile, Form, HTTPException, APIRouter, Depends
2
+ from fastapi.responses import JSONResponse
3
+ import uvicorn
4
+ from typing import List
5
+ import torch
6
+ import soundfile as sf
7
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
8
+ import re
9
+ import numpy as np
10
+ import cmudict
11
+ from io import BytesIO
12
+ import logging
13
+ from joblib import Memory
14
+ from difflib import SequenceMatcher
15
+ import eng_to_ipa as ipa_conv
16
+ import copy
17
+ from IPython.display import HTML, display
18
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
19
+ from pydub import AudioSegment
20
+ from Bio import pairwise2
21
+ from Bio.pairwise2 import format_alignment
22
+ import asyncio
23
+ from cachetools import TTLCache
24
+ import time
25
+ import os
26
+ from tempfile import NamedTemporaryFile
27
+ import subprocess
28
+ import librosa
29
+
30
+ # package imports
31
+ from services.transcribe import TranscriptionService
32
+ from utils.general_utils import clean_text
33
+
34
+ router = APIRouter()
35
+
36
+ service = TranscriptionService()
37
+ @router.post("/transcribe", summary="Transcribe audio into English")
38
+ async def transcribe(audio: UploadFile):
39
+ """
40
+ Transcribe the uploaded audio and return the transcript.
41
+
42
+ Args:
43
+ audio (UploadFile): Uploaded audio file.
44
+
45
+ Returns:
46
+ JSONResponse: Contains the transcript.
47
+ """
48
+ try:
49
+ # Call the service to process and transcribe the audio
50
+ transcript = await service.transcribe_audio(audio)
51
+ transcript = clean_text(transcript).strip()
52
+
53
+ response = {'transcript': transcript}
54
+ return JSONResponse(content=response)
55
+
56
+ except ValueError as ve:
57
+ logging.error(f"Validation error: {ve}")
58
+ raise HTTPException(status_code=400, detail=str(ve))
59
+ except Exception as e:
60
+ logging.error(f"Error during transcription: {e}")
61
+ raise HTTPException(status_code=500, detail="An error occurred during processing.")
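
The matching call for /transcribe needs only the audio file:

import requests

with open("sample.m4a", "rb") as f:
    resp = requests.post("http://localhost:7860/transcribe", files={"audio": ("sample.m4a", f)})
print(resp.json()["transcript"])  # cleaned, lower-cased transcript
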
app/services/__init__.py ADDED
File without changes
app/services/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (173 Bytes).
 
app/services/__pycache__/evaluate_pronunciation.cpython-39.pyc ADDED
Binary file (2.67 kB).
 
app/services/__pycache__/transcribe.cpython-39.pyc ADDED
Binary file (1.9 kB).
 
app/services/evaluate_pronunciation.py ADDED
@@ -0,0 +1,69 @@
1
+ import logging
2
+ import time
3
+ import asyncio
4
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
5
+
6
+ from models.ssl_singleton import ssl_model
7
+ from utils.general_utils import process_audio, clean_text
8
+ from modules.pronunciation_coach.pronunciation_assessor import PronunciationAssessor
9
+ from utils.cache import audio_cache
10
+ # process -> run inference -> structure output -> return
11
+
12
+ class PronunciationEvalService:
13
+ def __init__(self, transcript, audio):
14
+ """
15
+ Initialize the transcription service.
16
+
17
+ Args:
18
+ transcript (str): Ground truth transcript.
19
+ audio (UploadFile): Uploaded audio file.
20
+ """
21
+ self.ssl_model = ssl_model
22
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
23
+ self.device = 'cpu' # TEMP for testing
24
+ self.transcript = clean_text(transcript).strip()
25
+ self.audio = audio
26
+ self.filename = audio.filename
27
+ self.uttered_phonemes = None
28
+ self.assessor = None
29
+
30
+ async def get_uttered_phonemes(self):
31
+ # check if cache has filename
32
+ audio = self.audio
33
+ start_time = time.time()
34
+ audio_inputs = None
35
+ if await audio_cache.contains(self.filename):
36
+ async with audio_cache.lock:
37
+ if audio_cache.cache[self.filename]["uttered_phonemes"] is not None:
38
+ logging.info(f"Audio '{self.filename}' found in cache.")
39
+ end_time = time.time()
40
+ logging.info(f"Time for getting uttered phonemes: {end_time - start_time} seconds")
41
+ return audio_cache.cache[self.filename]["uttered_phonemes"]
42
+ else:
43
+ logging.info(f"Audio '{self.filename}' found in cache but not yet inferred. Running inference...")
44
+ audio_inputs = audio_cache.cache[self.filename]["audio_input"]
45
+ else:
46
+ logging.info(f"Audio '{self.filename}' not found in cache. Running inference...")
47
+
48
+ if audio_inputs is None:
49
+ cache_entry = await process_audio(audio, self.device)
50
+ audio_inputs = cache_entry["audio_input"]
51
+
52
+ uttered_phonemes = await self.ssl_model.infer_and_save_to_cache(self.filename, audio_inputs, self.device)
53
+ end_time = time.time()
54
+ logging.info(f"Time for getting uttered phonemes: {end_time - start_time} seconds")
55
+ return uttered_phonemes
56
+
57
+ async def generate_labels(self):
58
+ self.uttered_phonemes = await self.get_uttered_phonemes()
59
+ start_time = time.time()
60
+ self.assessor = PronunciationAssessor(self.transcript, self.uttered_phonemes)
61
+ self.assessor.convert_transcript_into_phonemes()
62
+ self.assessor.clean_ipa_phonemes()
63
+ self.assessor.split_phoneme_sequence()
64
+
65
+ labels = self.assessor.generate_labels_for_api()
66
+ end_time = time.time()
67
+ print("Time taken for label generation after getting uttered phonemes:", end_time - start_time)
68
+
69
+ return labels
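
Flow sketch of how a route is expected to drive this service (upload is the UploadFile that FastAPI injects, exactly as in routes/predict.py):

from services.evaluate_pronunciation import PronunciationEvalService

async def evaluate(upload, transcript: str):
    service = PronunciationEvalService(transcript, upload)  # cleans the transcript, keys the cache by filename
    return await service.generate_labels()                  # cache-aware phoneme inference + assessment
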
app/services/transcribe.py ADDED
@@ -0,0 +1,56 @@
1
+ import logging
2
+ import time
3
+ import asyncio
4
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
5
+
6
+ from models.transcriber_singleton import transcriber_model
7
+ from models.ssl_singleton import ssl_model
8
+ from utils.general_utils import process_audio, clean_text
9
+
10
+ # from utils.transcribe_utils import transcribe_into_English, clean_text
11
+ # process -> run inference -> structure output -> return
12
+
13
+ class TranscriptionService:
14
+ def __init__(self):
15
+ """
16
+ Initialize the transcription service.
17
+ """
18
+
19
+ self.transcriber_model = transcriber_model
20
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
21
+ self.device = 'cpu' # TEMP for testing
22
+
23
+ async def transcribe_audio(self, audio):
24
+ """
25
+ Process the uploaded audio file and return its transcription.
26
+
27
+ Args:
28
+ audio (UploadFile): Uploaded audio file.
29
+
30
+ Returns:
31
+ str: The transcript.
32
+ """
33
+ logging.info("Received transcription request!")
34
+
35
+ try:
36
+ # Step 1: Process the audio and check cache
37
+ start_time = time.time()
38
+ cache_entry = await process_audio(audio, self.device)
39
+ audio_input = cache_entry["audio_input"]
40
+
41
+ # Step 2: Start SSL inference in the background
42
+ asyncio.create_task(ssl_model.infer_and_save_to_cache(audio.filename, audio_input, self.device))
43
+
44
+ # Step 3: Get the transcript using Whisper
45
+ end_time = time.time()
46
+ logging.info(f"Time from call to finish processing audio: {end_time - start_time} seconds")
47
+ transcript = self.transcriber_model.transcribe_into_English(audio_input)
48
+ # Log processing time
49
+ another_end_time = time.time()
50
+ logging.info(f"Transcript: {transcript}, Time taken from processed audio to finish transcription: {another_end_time - end_time} seconds")
51
+
52
+ return transcript
53
+
54
+ except Exception as e:
55
+ logging.error(f"Error during transcription: {e}")
56
+ raise
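
The transcription service follows the same shape but is constructed once per process; SSL inference keeps running in the background so a later /predict on the same file can reuse the cached phonemes:

from services.transcribe import TranscriptionService

service = TranscriptionService()

async def handle(upload):
    return await service.transcribe_audio(upload)  # Whisper transcript; phonemes land in audio_cache asynchronously
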
notebook-inference.ipynb → app/tester-notebook.ipynb RENAMED
The diff for this file is too large to render.
 
app/utils/__init__.py ADDED
File without changes
app/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (170 Bytes).
 
app/utils/__pycache__/cache.cpython-39.pyc ADDED
Binary file (2.31 kB).
 
app/utils/__pycache__/general_utils.cpython-39.pyc ADDED
Binary file (2.42 kB).
 
app/utils/cache.py ADDED
@@ -0,0 +1,48 @@
1
+ import asyncio
2
+ from cachetools import TTLCache
3
+
4
+ class CacheManager:
5
+ _instance = None
6
+
7
+ def __new__(cls, *args, **kwargs):
8
+ if not cls._instance:
9
+ cls._instance = super(CacheManager, cls).__new__(cls, *args, **kwargs)
10
+ cls._instance._initialize()
11
+ return cls._instance
12
+
13
+ def _initialize(self):
14
+ # Initialize the cache and lock only once
15
+ self.cache = TTLCache(maxsize=100, ttl=300)
16
+ self.lock = asyncio.Lock()
17
+
18
+ async def set(self, key, value):
19
+ async with self.lock:
20
+ self.cache[key] = value
21
+
22
+ async def get(self, key):
23
+ async with self.lock:
24
+ return self.cache.get(key, None)
25
+
26
+ async def contains(self, key):
27
+ async with self.lock:
28
+ return key in self.cache
29
+
30
+ async def delete(self, key):
31
+ async with self.lock:
32
+ if key in self.cache:
33
+ del self.cache[key]
34
+
35
+ def set_without_lock(self, key, value):
36
+ self.cache[key] = value
37
+
38
+ def get_without_lock(self, key):
39
+ return self.cache.get(key, None)
40
+
41
+ def contains_without_lock(self, key):
42
+ return key in self.cache
43
+
44
+ def delete_without_lock(self, key):
45
+ if key in self.cache:
46
+ del self.cache[key]
47
+
48
+ audio_cache = CacheManager()
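
Usage sketch for the shared audio_cache (entries evict after the 300-second TTL configured above; the *_without_lock variants are for callers that already hold audio_cache.lock):

import asyncio
from utils.cache import audio_cache

async def demo():
    await audio_cache.set("clip.m4a", {"audio_input": None, "uttered_phonemes": None})
    entry = await audio_cache.get("clip.m4a")

    # A read-modify-write must hold the lock once and use the unlocked helpers,
    # since asyncio.Lock is not re-entrant.
    async with audio_cache.lock:
        entry = audio_cache.get_without_lock("clip.m4a")
        entry["uttered_phonemes"] = "ðə kæt"
        audio_cache.set_without_lock("clip.m4a", entry)

asyncio.run(demo())
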
app/utils/general_utils.py ADDED
@@ -0,0 +1,73 @@
1
+ import re
2
+ import logging
3
+ import torch
4
+ from tempfile import NamedTemporaryFile
5
+ import numpy as np
6
+ import librosa
7
+ from pydub import AudioSegment
8
+ import subprocess
9
+ import os
10
+ from fastapi import FastAPI, UploadFile, Form, HTTPException
11
+ from io import BytesIO
12
+ from utils.cache import audio_cache
13
+ import asyncio
14
+
15
+ async def process_audio(audio, device):
16
+ """
17
+ Process an uploaded audio file and prepare input for the model.
18
+
19
+ Args:
20
+ audio: The uploaded audio file.
21
+ device: The device (e.g., 'cuda' or 'cpu') to move tensors to.
22
+
23
+ Returns:
24
+ cache_entry: A dictionary containing processed audio and model input.
25
+ """
26
+ filename = audio.filename
27
+
28
+ # Check cache for processed audio
29
+ if await audio_cache.contains(filename):
30
+ logging.info(f"Audio '{filename}' found in cache.")
31
+ return await audio_cache.get(filename)
32
+
33
+ # Prevent race conditions during cache writes
34
+ async with audio_cache.lock:
35
+ # Double-check after acquiring lock
36
+ if audio_cache.contains_without_lock(filename):
37
+ logging.info(f"Audio '{filename}' found in cache after lock.")
38
+ return audio_cache.get_without_lock(filename)  # return the cached entry, not a membership flag
39
+ logging.info(f"Processing audio '{filename}'.")
40
+
41
+ # Read and preprocess the audio
42
+ audio_bytes = BytesIO(await audio.read())
43
+ audio_segment = AudioSegment.from_file(audio_bytes, format="m4a")
44
+ audio_samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float32)
45
+ max_val = np.iinfo(np.int16).max
46
+ audio_samples /= max_val
47
+
48
+ if audio_segment.channels > 1:
49
+ audio_samples = audio_samples.reshape(-1, audio_segment.channels).mean(axis=1)
50
+
51
+ audio_input = librosa.resample(audio_samples, orig_sr=audio_segment.frame_rate, target_sr=16000)
52
+ # input_values = processor(audio_input, return_tensors="pt", sampling_rate=16000).input_values.to(device)
53
+
54
+ # Cache the processed audio
55
+ cache_entry = {"audio_input": audio_input, "input_values": None, "ssl_logits": None}
56
+ audio_cache.set_without_lock(filename, cache_entry)
57
+ return cache_entry
58
+
59
+ def clean_text(text: str) -> str:
60
+ """
61
+ Remove punctuation from the input string except for special characters
62
+ that are part of a word, such as ' in I'm or - in hard-working.
63
+
64
+ Parameters:
65
+ text (str): Input string to clean.
66
+
67
+ Returns:
68
+ str: Cleaned string with allowed special characters retained.
69
+ """
70
+ # Allow letters, spaces, apostrophes, and hyphens within words
71
+ cleaned_text = re.sub(r'[^\w\s\'-]', '', text) # Remove punctuation except ' and -
72
+ cleaned_text = re.sub(r'\s+', ' ', cleaned_text) # Normalize spaces
73
+ return cleaned_text.lower().strip()
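
clean_text behaviour, following the two regexes above:

assert clean_text("I'm hard-working, really!") == "i'm hard-working really"
assert clean_text("Hello,   WORLD.") == "hello world"
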
inference.py DELETED
@@ -1,214 +0,0 @@
1
- import torch
2
- import librosa
3
- import soundfile as sf
4
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
5
- import re
6
- import numpy as np
7
- import cmudict
8
-
9
- # Load the processor and model
10
- MODEL_NAME = "mrrubino/wav2vec2-large-xlsr-53-l2-arctic-phoneme" # wav2vec based phoneme trascriber trained on L2-ARTIC
11
- processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
12
- model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
13
- model.eval()
14
-
15
- # Check device availability
16
- device = "cuda" if torch.cuda.is_available() else "cpu"
17
- model.to(device)
18
-
19
- def load_audio(audio_path, target_sr=16000):
20
- """Load an audio file and resample it to 16kHz."""
21
- audio, sr = librosa.load(audio_path, sr=target_sr)
22
- return audio
23
-
24
- # Original ARPAbet to IPA mapping from SoapBox Labs
25
- arpabet_to_ipa = {
26
- "AA": "a", "AE": "æ", "AH": "ʌ", "AO": "ɔ", "AW": "aʊ", "AY": "aɪ",
27
- "EH": "ɛ", "ER": "ɚ", "EY": "eɪ", "IH": "ɪ", "IY": "i", "OW": "oʊ",
28
- "OY": "ɔɪ", "UH": "ʊ", "UW": "u", "B": "b", "CH": "t͡ʃ", "D": "d",
29
- "DH": "ð", "F": "f", "G": "ɡ", "HH": "h", "JH": "dʒ", "K": "k",
30
- "L": "l", "M": "m", "N": "n", "NG": "ŋ", "P": "p", "R": "ɹ",
31
- "S": "s", "SH": "ʃ", "T": "t", "TH": "θ", "V": "v", "W": "w",
32
- "Y": "j", "Z": "z", "ZH": "ʒ"
33
- }
34
-
35
- # Invert the dictionary to map IPA to ARPAbet
36
- ipa_to_arpabet = {v: k for k, v in arpabet_to_ipa.items()}
37
-
38
- def convert_ipa_to_arpabet(ipa_words):
39
- """
40
- Convert a list of IPA words (strings of concatenated phonemes) to ARPAbet words.
41
-
42
- :param ipa_words: List of IPA words where each word is a string of concatenated phonemes.
43
- :return: List of lists, where each inner list contains ARPAbet phonemes for a word.
44
- """
45
- arpabet_words = []
46
- for word in ipa_words:
47
- # Break the word into phonemes
48
- phonemes = [] # Collect matched phonemes
49
- i = 0
50
- while i < len(word):
51
- matched = False
52
- # Match multi-character IPA phonemes first
53
- for ipa_phoneme in sorted(ipa_to_arpabet.keys(), key=len, reverse=True):
54
- if word[i:].startswith(ipa_phoneme):
55
- phonemes.append(ipa_to_arpabet[ipa_phoneme])
56
- i += len(ipa_phoneme)
57
- matched = True
58
- break
59
- # If no match, add an unknown marker and move forward
60
- if not matched:
61
- phonemes.append("<UNK>")
62
- i += 1
63
- # Append the list of phonemes for the word
64
- arpabet_words.append(phonemes)
65
- return arpabet_words
66
-
67
- def remove_numbers_from_phonemes(phon_list):
68
- """
69
- Remove all numbers from phonemes in a nested list.
70
-
71
- Parameters:
72
- phon_list (list of lists): Nested list of phonemes.
73
-
74
- Returns:
75
- list of lists: Updated nested list with numbers removed from phonemes.
76
- """
77
- cleaned_phon_list = []
78
- for word_phonemes in phon_list:
79
- cleaned_word = [re.sub(r'\d', '', phoneme) for phoneme in word_phonemes]
80
- cleaned_phon_list.append(cleaned_word)
81
- return cleaned_phon_list
82
-
83
- def align_phoneme_sequences(truth_words, uttered_words, gap_penalty=1, substitution_cost=1):
84
- """
85
- Align phoneme sequences separated by words.
86
-
87
- Parameters:
88
- truth_words (list of lists): Ground truth phoneme sequences grouped by words.
89
- uttered_words (list of lists): Uttered phoneme sequences grouped by words.
90
- gap_penalty (int): Penalty for gaps.
91
- substitution_cost (int): Cost for substitutions.
92
-
93
- Returns:
94
- alignment (list of tuples): Aligned phoneme sequences with '-' for gaps.
95
- """
96
- def align_two_sequences(seq1, seq2):
97
- """
98
- Align two sequences using dynamic programming.
99
- """
100
- n = len(seq1)
101
- m = len(seq2)
102
- dp = np.zeros((n + 1, m + 1))
103
-
104
- # Initialize DP table
105
- for i in range(n + 1):
106
- dp[i][0] = i * gap_penalty
107
- for j in range(m + 1):
108
- dp[0][j] = j * gap_penalty
109
-
110
- # Fill DP table
111
- for i in range(1, n + 1):
112
- for j in range(1, m + 1):
113
- match_cost = 0 if seq1[i - 1] == seq2[j - 1] else substitution_cost
114
- dp[i][j] = min(
115
- dp[i - 1][j - 1] + match_cost, # Match or substitution
116
- dp[i - 1][j] + gap_penalty, # Deletion
117
- dp[i][j - 1] + gap_penalty # Insertion
118
- )
119
-
120
- # Traceback to find alignment
121
- alignment_seq1 = []
122
- alignment_seq2 = []
123
- i, j = n, m
124
- while i > 0 or j > 0:
125
- if i > 0 and j > 0 and dp[i][j] == dp[i - 1][j - 1] + (0 if seq1[i - 1] == seq2[j - 1] else substitution_cost):
126
- alignment_seq1.append(seq1[i - 1])
127
- alignment_seq2.append(seq2[j - 1])
128
- i -= 1
129
- j -= 1
130
- elif i > 0 and dp[i][j] == dp[i - 1][j] + gap_penalty:
131
- alignment_seq1.append(seq1[i - 1])
132
- alignment_seq2.append('-')
133
- i -= 1
134
- else:
135
- alignment_seq1.append('-')
136
- alignment_seq2.append(seq2[j - 1])
137
- j -= 1
138
-
139
- return alignment_seq1[::-1], alignment_seq2[::-1]
140
-
141
- # Align each word pair
142
- alignment = []
143
- for truth_word, uttered_word in zip(truth_words, uttered_words):
144
- aligned_truth, aligned_uttered = align_two_sequences(truth_word, uttered_word)
145
- alignment.append((aligned_truth, aligned_uttered))
146
-
147
- return alignment
148
-
149
- def generate_phoneme_labels(data):
150
- """
151
- Generate phoneme labels for comparison of expected and uttered phonemes.
152
-
153
- Parameters:
154
- data (list of tuples): Each tuple contains (expected phonemes, uttered phonemes).
155
-
156
- Returns:
157
- list of tuples: Each tuple contains (phonemes, labels).
158
- Phonemes are from the expected list, and labels are binary (0: correct, 1: incorrect).
159
- """
160
- results = []
161
- for expected, uttered in data:
162
- labels = [
163
- 0 if exp == utt else 1
164
- for exp, utt in zip(expected, uttered)
165
- ]
166
- results.append((expected, labels))
167
- return results
168
-
169
- def convert_words_to_phonemes(words, cmu_dict):
170
- phonemes = []
171
- for word in words:
172
- if word in cmu_dict:
173
- phonemes.extend(cmu_dict[word][0]) # Use the first phoneme representation
174
- else:
175
- phonemes.append('<UNK>') # Append 'UNK' for unknown words
176
- return phonemes
177
-
178
- # RUN
179
-
180
- def predict():
181
- cmu = cmudict.dict()
182
-
183
- # Path to test audio file
184
- audio_path = '/content/drive/MyDrive/Test Audio/test5-good.m4a' # Replace with your audio file path
185
-
186
- # Define the script
187
- transcript = "the person that sat on the floor is punched"
188
-
189
- # Load audio and normalize
190
- audio_input = load_audio(audio_path)
191
- input_values = processor(audio_input, return_tensors="pt", sampling_rate=16000).input_values
192
- input_values = input_values.to(device)
193
-
194
- # Step 3: Perform inference
195
- with torch.no_grad():
196
- logits = model(input_values).logits
197
-
198
- # Step 4: Decode the phonemes
199
- predicted_ids = torch.argmax(logits, dim=-1)
200
- uttured_transcript = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
201
-
202
- # convert uttered ipa into SAMPA (for comparison)
203
- uttured_phons = convert_ipa_to_arpabet(uttured_transcript.split())
204
-
205
- # convert ground truth text into SAMPA (for comparison), and remove (ignore) stress markers (may upgrade to evaluate stress also later)
206
- trans_phons = [convert_words_to_phonemes([word], cmu) for word in transcript.split()]
207
- cleaned_trans_phons = remove_numbers_from_phonemes(trans_phons)
208
-
209
- # Generate labels
210
- alignment = align_phoneme_sequences(cleaned_trans_phons, uttured_phons)
211
- phoneme_labels = generate_phoneme_labels(alignment)
212
-
213
- print(phoneme_labels)
214
- return phoneme_labels