Update custom model files, README, and requirements
asr_pipeline.py  CHANGED  +77 -50
@@ -84,9 +84,9 @@ class ForcedAligner:
             if j > 0:
                 move = trellis[t, j - 1] + emission[t, tokens[j - 1]]
             else:
-                move =
+                move = -float("inf")

-            trellis[t + 1, j] =
+            trellis[t + 1, j] = max(stay, move)  # Viterbi: take best path

         return trellis

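For context, the two branches above are the inner step of a standard Viterbi fill over a (num_frames + 1) x (num_tokens + 1) trellis: stay scores emitting blank at frame t without advancing the token index, move scores emitting token j - 1, and j = 0 is only reachable by staying, hence the -inf sentinel. A minimal sketch of such a fill, assuming log-probability emissions; the surrounding loop and the stay definition sit outside this hunk (stay here mirrors the stay_score formula used later in _backtrack), so build_trellis and its initialisation are illustrative assumptions, not the repository's code:

    import torch

    def build_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
        # emission: (num_frames, vocab_size) log-probabilities from the acoustic model
        num_frames, num_tokens = emission.size(0), len(tokens)
        trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
        trellis[0, 0] = 0.0  # start state: no frames consumed, no tokens emitted
        for t in range(num_frames):
            for j in range(num_tokens + 1):
                stay = trellis[t, j] + emission[t, blank_id]  # emit blank, keep j
                if j > 0:
                    move = trellis[t, j - 1] + emission[t, tokens[j - 1]]
                else:
                    move = -float("inf")  # no token can have been emitted before the first
                trellis[t + 1, j] = max(stay, move)  # Viterbi: take best path
        return trellis

On a toy input, build_trellis(torch.randn(5, 4).log_softmax(-1), [1, 2]) fills a 6 x 3 table whose bottom-right cell holds the best alignment score that _backtrack then traces.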
@@ -93,42 +93,52 @@ class ForcedAligner:
     @staticmethod
     def _backtrack(
         trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
-    ) -> list[tuple[int, int,
-        """Backtrack through trellis to find optimal alignment
+    ) -> list[tuple[int, int, float]]:
+        """Backtrack through trellis to find optimal forced monotonic alignment.
+
+        Guarantees:
+        - All tokens are emitted exactly once
+        - Strictly monotonic: each token's frames come after previous token's
+        - No frame skipping or token teleporting

         Returns list of (token_id, start_frame, end_frame) for each token.
         """
         num_frames = emission.size(0)
         num_tokens = len(tokens)

-
-
-        j = num_tokens
-        path = []  # Will store (frame, token_index) pairs
+        if num_tokens == 0:
+            return []

-
-
-
-        #
+        # Find the best ending point (should be at num_tokens)
+        # But verify trellis reached a valid state
+        if trellis[num_frames, num_tokens] == -float("inf"):
+            # Alignment failed - fall back to uniform distribution
+            frames_per_token = num_frames / num_tokens
+            return [
+                (tokens[i], i * frames_per_token, (i + 1) * frames_per_token)
+                for i in range(num_tokens)
+            ]

-
-
-
-
+        # Backtrack: find where each token transition occurred
+        # path[i] = frame where token i was first emitted
+        token_frames: list[list[int]] = [[] for _ in range(num_tokens)]
+
+        t = num_frames
+        j = num_tokens

-
+        while t > 0 and j > 0:
+            # Check: did we transition from j-1 to j at frame t-1?
             stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
             move_score = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

-        if move_score
+            if move_score >= stay_score:
                 # Token j-1 was emitted at frame t-1
-
+                token_frames[j - 1].insert(0, t - 1)
                 j -= 1
-
+            # Always decrement time (monotonic)
             t -= 1

-
-
-
-
-        return []
+        # Handle any remaining tokens at the start (edge case)
+        while j > 0:
+            token_frames[j - 1].insert(0, 0)
+            j -= 1
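The backtracking pass can be sanity-checked on a toy emission matrix in which each token has one obvious peak. This reuses the hypothetical build_trellis sketch above together with the _backtrack method from this diff (assuming ForcedAligner is importable from asr_pipeline); the scores are made up:

    # 4 frames, vocab {0: blank, 1, 2}; log-scores favour token 1 at frame 0
    # and token 2 at frame 2, with blank cheap everywhere else.
    emission = torch.full((4, 3), -10.0)
    emission[:, 0] = -1.0   # blank
    emission[0, 1] = -0.1   # token 1 peaks at frame 0
    emission[2, 2] = -0.1   # token 2 peaks at frame 2

    trellis = build_trellis(emission, tokens=[1, 2])
    spans = ForcedAligner._backtrack(trellis, emission, tokens=[1, 2])
    # Per the docstring guarantees: both tokens get spans, token 1 before token 2.

Note the >= in move_score >= stay_score: on an exact tie the loop prefers advancing the token index, which is what keeps every token emitted even when blank and token scores are equal.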
@@ -135,24 +145,44 @@ class ForcedAligner:

+        # Convert to spans with sub-frame interpolation
         token_spans = []
-
-
-
-
-
-
-
-
-
-
+        for token_idx, frames in enumerate(token_frames):
+            if not frames:
+                # Token never emitted - assign minimal span after previous
+                if token_spans:
+                    prev_end = token_spans[-1][2]
+                    frames = [int(prev_end)]
+                else:
+                    frames = [0]
+
+            token_id = tokens[token_idx]
+            frame_probs = emission[frames, token_id]
+            peak_idx = int(torch.argmax(frame_probs).item())
+            peak_frame = frames[peak_idx]
+
+            # Sub-frame interpolation using quadratic fit around peak
+            if len(frames) >= 3 and 0 < peak_idx < len(frames) - 1:
+                y0 = frame_probs[peak_idx - 1].item()
+                y1 = frame_probs[peak_idx].item()
+                y2 = frame_probs[peak_idx + 1].item()
+
+                denom = y0 - 2 * y1 + y2
+                if abs(denom) > 1e-10:
+                    offset = 0.5 * (y0 - y2) / denom
+                    offset = max(-0.5, min(0.5, offset))
+                else:
+                    offset = 0.0
+                refined_frame = peak_frame + offset
+            else:
+                refined_frame = float(peak_frame)

-        token_spans.append((
-        i += 1
+            token_spans.append((token_id, refined_frame, refined_frame + 1.0))

         return token_spans

-        #
-        #
-
+    # Offset compensation for Wav2Vec2-BASE systematic bias (in seconds)
+    # Calibrated on librispeech-alignments dataset
+    START_OFFSET = 0.06  # Subtract from start times (shift earlier)
+    END_OFFSET = -0.03  # Add to end times (shift later)

     @classmethod
     def align(
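The sub-frame refinement is ordinary three-point parabolic peak interpolation: fit a parabola through the scores at the peak frame and its two neighbours, take the vertex, and clamp it to half a frame. A worked example with made-up scores:

    # Emission scores for one token at frames 9, 10, 11 (made-up numbers):
    y0, y1, y2 = -2.0, -0.5, -1.0     # peak at the middle frame (frame 10)

    denom = y0 - 2 * y1 + y2          # -2.0, the parabola's curvature
    offset = 0.5 * (y0 - y2) / denom  # 0.5 * (-1.0) / -2.0 = 0.25
    refined_frame = 10 + offset       # 10.25: the true peak sits a quarter-frame late

At the model's 20 ms stride that quarter-frame is 5 ms, which is the resolution gained by refining below the frame grid.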
@@ -162,7 +192,6 @@ class ForcedAligner:
         sample_rate: int = 16000,
         _language: str = "eng",
         _batch_size: int = 16,
-        offset_compensation: float | None = None,
     ) -> list[dict]:
         """Align transcript to audio and return word-level timestamps.

@@ -174,9 +203,6 @@ class ForcedAligner:
             sample_rate: Audio sample rate (default 16000)
             _language: ISO-639-3 language code (default "eng" for English, unused)
             _batch_size: Batch size for alignment model (unused)
-            offset_compensation: Time offset in seconds to subtract from timestamps
-                to compensate for Wav2Vec2 look-ahead (default: 0.04s / 40ms).
-                Set to 0 to disable.

         Returns:
             List of dicts with 'word', 'start', 'end' keys
@@ -232,8 +258,9 @@ class ForcedAligner:
         # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
         frame_duration = 320 / cls._bundle.sample_rate

-        # Apply offset compensation for Wav2Vec2
-
+        # Apply separate offset compensation for start/end (Wav2Vec2 systematic bias)
+        start_offset = cls.START_OFFSET
+        end_offset = cls.END_OFFSET

         # Group aligned tokens into words based on pipe separator
         words = text.split()
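Concretely, with the 320-sample stride at 16 kHz each frame is 20 ms, so the conversion plus the calibrated constants works out as follows (frame indices are made up):

    frame_duration = 320 / 16000      # 0.02 s per frame
    START_OFFSET, END_OFFSET = 0.06, -0.03

    start_frame, end_frame = 50, 75   # hypothetical word span
    start_time = max(0.0, start_frame * frame_duration - START_OFFSET)  # 1.00 - 0.06 = 0.94
    end_time = max(0.0, end_frame * frame_duration - END_OFFSET)        # 1.50 + 0.03 = 1.53

Because END_OFFSET is negative, subtracting it pushes end times later, exactly as the "shift later" comment on the constant says.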
@@ -246,8 +273,8 @@ class ForcedAligner:
         for token_id, start_frame, end_frame in alignment_path:
             if token_id == separator_id:  # Word separator
                 if current_word_start is not None and word_idx < len(words):
-                    start_time = max(0.0, current_word_start * frame_duration -
-                    end_time = max(0.0, current_word_end * frame_duration -
+                    start_time = max(0.0, current_word_start * frame_duration - start_offset)
+                    end_time = max(0.0, current_word_end * frame_duration - end_offset)
                     word_timestamps.append(
                         {
                             "word": words[word_idx],
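The max(0.0, ...) clamp matters for words at the very start of the audio: with a 60 ms start offset, any word beginning before frame 3 would otherwise get a negative start time. Continuing the numbers above:

    start_frame = 1                                     # word begins almost immediately
    raw = start_frame * frame_duration - START_OFFSET   # 0.02 - 0.06 = -0.04
    start_time = max(0.0, raw)                          # clamped to 0.0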
@@ -265,8 +292,8 @@ class ForcedAligner:

         # Don't forget the last word
         if current_word_start is not None and word_idx < len(words):
-            start_time = max(0.0, current_word_start * frame_duration -
-            end_time = max(0.0, current_word_end * frame_duration -
+            start_time = max(0.0, current_word_start * frame_duration - start_offset)
+            end_time = max(0.0, current_word_end * frame_duration - end_offset)
             word_timestamps.append(
                 {
                     "word": words[word_idx],
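Taken together, the outward-facing change is small: offset_compensation is gone from align's signature and the bias correction is baked in as class constants. A minimal usage sketch; only sample_rate, _language, _batch_size and the return shape appear in this diff, so the leading waveform/transcript arguments are assumptions about the rest of the signature:

    import torch
    from asr_pipeline import ForcedAligner

    waveform = torch.zeros(16000 * 2)   # placeholder: 2 s of 16 kHz mono audio
    transcript = "hello world"

    words = ForcedAligner.align(waveform, transcript, sample_rate=16000)
    for w in words:                     # dicts with 'word', 'start', 'end' keys
        print(f"{w['word']}: {w['start']:.2f}s to {w['end']:.2f}s")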
|