bachtom125 committed
Commit 385e141 · 1 Parent(s): 58446fa

first commit
.devcontainer/devcontainer.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "name": "Python 3",
+   // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
+   "image": "mcr.microsoft.com/devcontainers/python:1-3.11-bullseye",
+   "customizations": {
+     "codespaces": {
+       "openFiles": [
+         "README.md",
+         "app.py"
+       ]
+     },
+     "vscode": {
+       "settings": {},
+       "extensions": [
+         "ms-python.python",
+         "ms-python.vscode-pylance"
+       ]
+     }
+   },
+   "updateContentCommand": "[ -f packages.txt ] && sudo apt update && sudo apt upgrade -y && sudo xargs apt install -y <packages.txt; [ -f requirements.txt ] && pip3 install --user -r requirements.txt; pip3 install --user streamlit; echo '✅ Packages installed and Requirements met'",
+   "postAttachCommand": {
+     "server": "streamlit run app.py --server.enableCORS false --server.enableXsrfProtection false"
+   },
+   "portsAttributes": {
+     "8501": {
+       "label": "Application",
+       "onAutoForward": "openPreview"
+     }
+   },
+   "forwardPorts": [
+     8501
+   ]
+ }
.dockerignore ADDED
@@ -0,0 +1,12 @@
+ # Ignore unnecessary files
+ .git
+ __pycache__
+ *.pyc
+ *.pyo
+ *.log
+ *.tmp
+ *.zip
+ *.tar.gz
+ Datasets/
+ .venv/
+ Audios/
.gitignore ADDED
Binary file (96 Bytes).
Dockerfile ADDED
@@ -0,0 +1,25 @@
+ # Use an official Python runtime as a parent image
+ FROM python:3.9-slim
+
+ # Set the working directory in the container
+ WORKDIR /app
+
+ # Install system dependencies (for librosa and other packages)
+ RUN apt-get update && apt-get install -y \
+     libsndfile1 \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements.txt first to leverage Docker's caching
+ COPY requirements.txt .
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the rest of the application code
+ COPY . .
+
+ # Expose port 10000 (or whatever port your app uses)
+ EXPOSE 10000
+
+ # Command to run the application using Uvicorn
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "10000"]
app.py ADDED
@@ -0,0 +1,258 @@
+
+ from fastapi import FastAPI, UploadFile, Form, HTTPException
+ from fastapi.responses import JSONResponse
+ import uvicorn
+ from typing import List
+ import torch
+ import librosa
+ import soundfile as sf
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
+ import re
+ import numpy as np
+ import cmudict
+ from io import BytesIO
+ import os
+ import logging
+
+ logging.basicConfig(level=logging.INFO)
+
+ cmu = cmudict.dict()
+
+ # Initialize FastAPI app
+ app = FastAPI()
+
+ # Load the processor and model
+ MODEL_NAME = "mrrubino/wav2vec2-large-xlsr-53-l2-arctic-phoneme"  # wav2vec2-based phoneme transcriber trained on L2-ARCTIC
+ processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
+ model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
+ model.eval()
+
+ # Check device availability
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+
+ def load_audio(audio_path, target_sr=16000):
+     """Load an audio file and resample it to 16 kHz."""
+     audio, sr = librosa.load(audio_path, sr=target_sr)
+     return audio
+
+ # Original ARPAbet-to-IPA mapping from SoapBox Labs
+ arpabet_to_ipa = {
+     "AA": "a", "AE": "æ", "AH": "ʌ", "AO": "ɔ", "AW": "aʊ", "AY": "aɪ",
+     "EH": "ɛ", "ER": "ɚ", "EY": "eɪ", "IH": "ɪ", "IY": "i", "OW": "oʊ",
+     "OY": "ɔɪ", "UH": "ʊ", "UW": "u", "B": "b", "CH": "t͡ʃ", "D": "d",
+     "DH": "ð", "F": "f", "G": "ɡ", "HH": "h", "JH": "dʒ", "K": "k",
+     "L": "l", "M": "m", "N": "n", "NG": "ŋ", "P": "p", "R": "ɹ",
+     "S": "s", "SH": "ʃ", "T": "t", "TH": "θ", "V": "v", "W": "w",
+     "Y": "j", "Z": "z", "ZH": "ʒ"
+ }
+
+ # Invert the dictionary to map IPA to ARPAbet
+ ipa_to_arpabet = {v: k for k, v in arpabet_to_ipa.items()}
+
+ def convert_ipa_to_arpabet(ipa_words):
+     """
+     Convert a list of IPA words (strings of concatenated phonemes) to ARPAbet words.
+
+     :param ipa_words: List of IPA words, where each word is a string of concatenated phonemes.
+     :return: List of lists, where each inner list contains the ARPAbet phonemes for a word.
+     """
+     arpabet_words = []
+     for word in ipa_words:
+         # Break the word into phonemes
+         phonemes = []  # Collect matched phonemes
+         i = 0
+         while i < len(word):
+             matched = False
+             # Match multi-character IPA phonemes first
+             for ipa_phoneme in sorted(ipa_to_arpabet.keys(), key=len, reverse=True):
+                 if word[i:].startswith(ipa_phoneme):
+                     phonemes.append(ipa_to_arpabet[ipa_phoneme])
+                     i += len(ipa_phoneme)
+                     matched = True
+                     break
+             # If no match, add an unknown marker and move forward
+             if not matched:
+                 phonemes.append("<UNK>")
+                 i += 1
+         # Append the list of phonemes for the word
+         arpabet_words.append(phonemes)
+     return arpabet_words
+
+ def remove_numbers_from_phonemes(phon_list):
+     """
+     Remove all digits (stress markers) from phonemes in a nested list.
+
+     Parameters:
+         phon_list (list of lists): Nested list of phonemes.
+
+     Returns:
+         list of lists: Updated nested list with digits removed from phonemes.
+     """
+     cleaned_phon_list = []
+     for word_phonemes in phon_list:
+         cleaned_word = [re.sub(r'\d', '', phoneme) for phoneme in word_phonemes]
+         cleaned_phon_list.append(cleaned_word)
+     return cleaned_phon_list
+
+ def align_phoneme_sequences(truth_words, uttered_words, gap_penalty=1, substitution_cost=1):
+     """
+     Align phoneme sequences, grouped by word.
+
+     Parameters:
+         truth_words (list of lists): Ground-truth phoneme sequences grouped by word.
+         uttered_words (list of lists): Uttered phoneme sequences grouped by word.
+         gap_penalty (int): Penalty for gaps.
+         substitution_cost (int): Cost for substitutions.
+
+     Returns:
+         alignment (list of tuples): Aligned phoneme sequences with '-' for gaps.
+     """
+     def align_two_sequences(seq1, seq2):
+         """
+         Align two sequences using dynamic programming.
+         """
+         n = len(seq1)
+         m = len(seq2)
+         dp = np.zeros((n + 1, m + 1))
+
+         # Initialize DP table
+         for i in range(n + 1):
+             dp[i][0] = i * gap_penalty
+         for j in range(m + 1):
+             dp[0][j] = j * gap_penalty
+
+         # Fill DP table
+         for i in range(1, n + 1):
+             for j in range(1, m + 1):
+                 match_cost = 0 if seq1[i - 1] == seq2[j - 1] else substitution_cost
+                 dp[i][j] = min(
+                     dp[i - 1][j - 1] + match_cost,  # Match or substitution
+                     dp[i - 1][j] + gap_penalty,     # Deletion
+                     dp[i][j - 1] + gap_penalty      # Insertion
+                 )
+
+         # Traceback to recover the alignment
+         alignment_seq1 = []
+         alignment_seq2 = []
+         i, j = n, m
+         while i > 0 or j > 0:
+             if i > 0 and j > 0 and dp[i][j] == dp[i - 1][j - 1] + (0 if seq1[i - 1] == seq2[j - 1] else substitution_cost):
+                 alignment_seq1.append(seq1[i - 1])
+                 alignment_seq2.append(seq2[j - 1])
+                 i -= 1
+                 j -= 1
+             elif i > 0 and dp[i][j] == dp[i - 1][j] + gap_penalty:
+                 alignment_seq1.append(seq1[i - 1])
+                 alignment_seq2.append('-')
+                 i -= 1
+             else:
+                 alignment_seq1.append('-')
+                 alignment_seq2.append(seq2[j - 1])
+                 j -= 1
+
+         return alignment_seq1[::-1], alignment_seq2[::-1]
+
+     # Align each word pair
+     alignment = []
+     for truth_word, uttered_word in zip(truth_words, uttered_words):
+         aligned_truth, aligned_uttered = align_two_sequences(truth_word, uttered_word)
+         alignment.append((aligned_truth, aligned_uttered))
+
+     return alignment
+
+ def generate_phoneme_labels(data):
+     """
+     Generate phoneme labels by comparing expected and uttered phonemes.
+
+     Parameters:
+         data (list of tuples): Each tuple contains (expected phonemes, uttered phonemes).
+
+     Returns:
+         list of tuples: Each tuple contains (phonemes, labels).
+             Phonemes are from the expected list, and labels are binary (0: correct, 1: incorrect).
+     """
+     results = []
+     for expected, uttered in data:
+         labels = [
+             0 if exp == utt else 1
+             for exp, utt in zip(expected, uttered)
+         ]
+         results.append((expected, labels))
+     return results
+
+ def convert_words_to_phonemes(words, cmu_dict):
+     phonemes = []
+     for word in words:
+         if word in cmu_dict:
+             phonemes.extend(cmu_dict[word][0])  # Use the first pronunciation
+         else:
+             phonemes.append('<UNK>')  # Append '<UNK>' for unknown words
+     return phonemes
+
+ # Health check
+ @app.get("/")
+ def home():
+     return "Healthy bro!"
+
+ # Take in both audio and transcript from the user
+ @app.post("/predict")
+ async def predict(audio: UploadFile, transcript: str = Form(...)):
+     """
+     Predict phoneme labels from uploaded audio and the provided transcript.
+
+     Args:
+         audio (UploadFile): Uploaded audio file (WAV/MP3).
+         transcript (str): Ground-truth transcript.
+
+     Returns:
+         JSONResponse: Contains phoneme labels.
+     """
+     logging.info("Received prediction request!")
+
+     # Validate file extension
+     allowed_extensions = {"wav", "mp3"}
+     filename = audio.filename.lower()
+
+     if not filename.endswith(tuple(allowed_extensions)):
+         raise HTTPException(
+             status_code=400,
+             detail="Invalid file type. Only WAV and MP3 files are supported.",
+         )
+
+     # Load and preprocess the audio
+     try:
+         audio_bytes = BytesIO(await audio.read())
+         audio_input, sr = librosa.load(audio_bytes, sr=16000)
+         input_values = processor(audio_input, return_tensors="pt", sampling_rate=16000).input_values
+         input_values = input_values.to(device)
+
+         # Perform inference
+         with torch.no_grad():
+             logits = model(input_values).logits
+
+         # Decode the phonemes
+         predicted_ids = torch.argmax(logits, dim=-1)
+         uttered_transcript = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+
+         # Convert the uttered IPA into ARPAbet (for comparison)
+         uttered_phons = convert_ipa_to_arpabet(uttered_transcript.split())
+
+         # Convert the ground-truth text into ARPAbet (for comparison) and strip stress markers
+         trans_phons = [convert_words_to_phonemes([word], cmu) for word in transcript.split()]
+         cleaned_trans_phons = remove_numbers_from_phonemes(trans_phons)
+
+         # Generate labels
+         alignment = align_phoneme_sequences(cleaned_trans_phons, uttered_phons)
+         phoneme_labels = generate_phoneme_labels(alignment)
+
+         return JSONResponse(content={"phoneme_labels": phoneme_labels})
+
+     except Exception as e:
+         logging.error(f"Error during prediction: {e}")
+         raise HTTPException(status_code=500, detail="An error occurred during processing.")
+
+ if __name__ == '__main__':
+     port = os.environ.get("PORT", 10000)  # Default to 10000 if PORT is not set
+     logging.info(f"Starting server on PORT {port}")
+     uvicorn.run("app:app", host="0.0.0.0", port=int(port), log_level="info")
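
For reference, a client call to the `/predict` endpoint could look like the following sketch. It is not part of the commit and assumes the server is running locally on port 10000, the third-party `requests` package is installed, and a `sample.wav` recording of the transcript exists; the field names `audio` and `transcript` match the endpoint's `UploadFile` and `Form(...)` parameters.

```python
# Hypothetical client for the /predict endpoint (a sketch under the
# assumptions stated above, not the repo's own tooling).
import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:10000/predict",
        files={"audio": ("sample.wav", f, "audio/wav")},  # UploadFile parameter
        data={"transcript": "the cat sat on the mat"},    # Form(...) parameter
    )
print(resp.json())  # e.g. {"phoneme_labels": [[["DH", "AH"], [0, 0]], ...]}
```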
inference.py ADDED
@@ -0,0 +1,214 @@
+ import torch
+ import librosa
+ import soundfile as sf
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
+ import re
+ import numpy as np
+ import cmudict
+
+ # Load the processor and model
+ MODEL_NAME = "mrrubino/wav2vec2-large-xlsr-53-l2-arctic-phoneme"  # wav2vec2-based phoneme transcriber trained on L2-ARCTIC
+ processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
+ model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
+ model.eval()
+
+ # Check device availability
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+
+ def load_audio(audio_path, target_sr=16000):
+     """Load an audio file and resample it to 16 kHz."""
+     audio, sr = librosa.load(audio_path, sr=target_sr)
+     return audio
+
+ # Original ARPAbet-to-IPA mapping from SoapBox Labs
+ arpabet_to_ipa = {
+     "AA": "a", "AE": "æ", "AH": "ʌ", "AO": "ɔ", "AW": "aʊ", "AY": "aɪ",
+     "EH": "ɛ", "ER": "ɚ", "EY": "eɪ", "IH": "ɪ", "IY": "i", "OW": "oʊ",
+     "OY": "ɔɪ", "UH": "ʊ", "UW": "u", "B": "b", "CH": "t͡ʃ", "D": "d",
+     "DH": "ð", "F": "f", "G": "ɡ", "HH": "h", "JH": "dʒ", "K": "k",
+     "L": "l", "M": "m", "N": "n", "NG": "ŋ", "P": "p", "R": "ɹ",
+     "S": "s", "SH": "ʃ", "T": "t", "TH": "θ", "V": "v", "W": "w",
+     "Y": "j", "Z": "z", "ZH": "ʒ"
+ }
+
+ # Invert the dictionary to map IPA to ARPAbet
+ ipa_to_arpabet = {v: k for k, v in arpabet_to_ipa.items()}
+
+ def convert_ipa_to_arpabet(ipa_words):
+     """
+     Convert a list of IPA words (strings of concatenated phonemes) to ARPAbet words.
+
+     :param ipa_words: List of IPA words, where each word is a string of concatenated phonemes.
+     :return: List of lists, where each inner list contains the ARPAbet phonemes for a word.
+     """
+     arpabet_words = []
+     for word in ipa_words:
+         # Break the word into phonemes
+         phonemes = []  # Collect matched phonemes
+         i = 0
+         while i < len(word):
+             matched = False
+             # Match multi-character IPA phonemes first
+             for ipa_phoneme in sorted(ipa_to_arpabet.keys(), key=len, reverse=True):
+                 if word[i:].startswith(ipa_phoneme):
+                     phonemes.append(ipa_to_arpabet[ipa_phoneme])
+                     i += len(ipa_phoneme)
+                     matched = True
+                     break
+             # If no match, add an unknown marker and move forward
+             if not matched:
+                 phonemes.append("<UNK>")
+                 i += 1
+         # Append the list of phonemes for the word
+         arpabet_words.append(phonemes)
+     return arpabet_words
+
+ def remove_numbers_from_phonemes(phon_list):
+     """
+     Remove all digits (stress markers) from phonemes in a nested list.
+
+     Parameters:
+         phon_list (list of lists): Nested list of phonemes.
+
+     Returns:
+         list of lists: Updated nested list with digits removed from phonemes.
+     """
+     cleaned_phon_list = []
+     for word_phonemes in phon_list:
+         cleaned_word = [re.sub(r'\d', '', phoneme) for phoneme in word_phonemes]
+         cleaned_phon_list.append(cleaned_word)
+     return cleaned_phon_list
+
+ def align_phoneme_sequences(truth_words, uttered_words, gap_penalty=1, substitution_cost=1):
+     """
+     Align phoneme sequences, grouped by word.
+
+     Parameters:
+         truth_words (list of lists): Ground-truth phoneme sequences grouped by word.
+         uttered_words (list of lists): Uttered phoneme sequences grouped by word.
+         gap_penalty (int): Penalty for gaps.
+         substitution_cost (int): Cost for substitutions.
+
+     Returns:
+         alignment (list of tuples): Aligned phoneme sequences with '-' for gaps.
+     """
+     def align_two_sequences(seq1, seq2):
+         """
+         Align two sequences using dynamic programming.
+         """
+         n = len(seq1)
+         m = len(seq2)
+         dp = np.zeros((n + 1, m + 1))
+
+         # Initialize DP table
+         for i in range(n + 1):
+             dp[i][0] = i * gap_penalty
+         for j in range(m + 1):
+             dp[0][j] = j * gap_penalty
+
+         # Fill DP table
+         for i in range(1, n + 1):
+             for j in range(1, m + 1):
+                 match_cost = 0 if seq1[i - 1] == seq2[j - 1] else substitution_cost
+                 dp[i][j] = min(
+                     dp[i - 1][j - 1] + match_cost,  # Match or substitution
+                     dp[i - 1][j] + gap_penalty,     # Deletion
+                     dp[i][j - 1] + gap_penalty      # Insertion
+                 )
+
+         # Traceback to recover the alignment
+         alignment_seq1 = []
+         alignment_seq2 = []
+         i, j = n, m
+         while i > 0 or j > 0:
+             if i > 0 and j > 0 and dp[i][j] == dp[i - 1][j - 1] + (0 if seq1[i - 1] == seq2[j - 1] else substitution_cost):
+                 alignment_seq1.append(seq1[i - 1])
+                 alignment_seq2.append(seq2[j - 1])
+                 i -= 1
+                 j -= 1
+             elif i > 0 and dp[i][j] == dp[i - 1][j] + gap_penalty:
+                 alignment_seq1.append(seq1[i - 1])
+                 alignment_seq2.append('-')
+                 i -= 1
+             else:
+                 alignment_seq1.append('-')
+                 alignment_seq2.append(seq2[j - 1])
+                 j -= 1
+
+         return alignment_seq1[::-1], alignment_seq2[::-1]
+
+     # Align each word pair
+     alignment = []
+     for truth_word, uttered_word in zip(truth_words, uttered_words):
+         aligned_truth, aligned_uttered = align_two_sequences(truth_word, uttered_word)
+         alignment.append((aligned_truth, aligned_uttered))
+
+     return alignment
+
+ def generate_phoneme_labels(data):
+     """
+     Generate phoneme labels by comparing expected and uttered phonemes.
+
+     Parameters:
+         data (list of tuples): Each tuple contains (expected phonemes, uttered phonemes).
+
+     Returns:
+         list of tuples: Each tuple contains (phonemes, labels).
+             Phonemes are from the expected list, and labels are binary (0: correct, 1: incorrect).
+     """
+     results = []
+     for expected, uttered in data:
+         labels = [
+             0 if exp == utt else 1
+             for exp, utt in zip(expected, uttered)
+         ]
+         results.append((expected, labels))
+     return results
+
+ def convert_words_to_phonemes(words, cmu_dict):
+     phonemes = []
+     for word in words:
+         if word in cmu_dict:
+             phonemes.extend(cmu_dict[word][0])  # Use the first pronunciation
+         else:
+             phonemes.append('<UNK>')  # Append '<UNK>' for unknown words
+     return phonemes
+
+ # RUN
+
+ def predict():
+     cmu = cmudict.dict()
+
+     # Path to the test audio file
+     audio_path = '/content/drive/MyDrive/Test Audio/test5-good.m4a'  # Replace with your audio file path
+
+     # Define the script
+     transcript = "the person that sat on the floor is punched"
+
+     # Load audio and normalize
+     audio_input = load_audio(audio_path)
+     input_values = processor(audio_input, return_tensors="pt", sampling_rate=16000).input_values
+     input_values = input_values.to(device)
+
+     # Perform inference
+     with torch.no_grad():
+         logits = model(input_values).logits
+
+     # Decode the phonemes
+     predicted_ids = torch.argmax(logits, dim=-1)
+     uttered_transcript = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+
+     # Convert the uttered IPA into ARPAbet (for comparison)
+     uttered_phons = convert_ipa_to_arpabet(uttered_transcript.split())
+
+     # Convert the ground-truth text into ARPAbet (for comparison) and strip stress markers (stress evaluation may be added later)
+     trans_phons = [convert_words_to_phonemes([word], cmu) for word in transcript.split()]
+     cleaned_trans_phons = remove_numbers_from_phonemes(trans_phons)
+
+     # Generate labels
+     alignment = align_phoneme_sequences(cleaned_trans_phons, uttered_phons)
+     phoneme_labels = generate_phoneme_labels(alignment)
+
+     print(phoneme_labels)
+     return phoneme_labels
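
To see what the alignment and labelling helpers produce without recording audio, a toy walk-through like the sketch below exercises the text-side pipeline on a hand-built pair; it assumes inference.py is importable from the working directory (note that importing it loads the wav2vec2 checkpoint, which is slow on first run).

```python
# Toy example of the alignment/labelling helpers (a sketch; the IPA strings
# stand in for model output, with the final /t/ of "cat" missing).
from inference import convert_ipa_to_arpabet, align_phoneme_sequences, generate_phoneme_labels

truth = [["DH", "AH"], ["K", "AE", "T"]]        # "the cat" per CMUdict, stress stripped
uttered = convert_ipa_to_arpabet(["ðʌ", "kæ"])  # -> [['DH', 'AH'], ['K', 'AE']]

alignment = align_phoneme_sequences(truth, uttered)
print(generate_phoneme_labels(alignment))
# [(['DH', 'AH'], [0, 0]), (['K', 'AE', 'T'], [0, 0, 1])]  -- dropped /t/ flagged as incorrect
```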
notebook-inference.ipynb ADDED
The diff for this file is too large to render.
requirements.txt ADDED
Binary file (542 Bytes).