mazesmazes committed (verified)
Commit 14dee2f · Parent(s): 1b9681a

Update custom model files, README, and requirements

Files changed (1)
  1. asr_pipeline.py +77 -50
asr_pipeline.py CHANGED
@@ -84,75 +84,105 @@ class ForcedAligner:
                 if j > 0:
                     move = trellis[t, j - 1] + emission[t, tokens[j - 1]]
                 else:
-                    move = torch.tensor(-float("inf"))
+                    move = -float("inf")

-                trellis[t + 1, j] = torch.logaddexp(torch.tensor(stay), move).item()
+                trellis[t + 1, j] = max(stay, move)  # Viterbi: take best path

         return trellis

     @staticmethod
     def _backtrack(
         trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
-    ) -> list[tuple[int, int, int]]:
-        """Backtrack through trellis to find optimal alignment path.
+    ) -> list[tuple[int, int, float]]:
+        """Backtrack through trellis to find optimal forced monotonic alignment.
+
+        Guarantees:
+        - All tokens are emitted exactly once
+        - Strictly monotonic: each token's frames come after previous token's
+        - No frame skipping or token teleporting

         Returns list of (token_id, start_frame, end_frame) for each token.
         """
         num_frames = emission.size(0)
         num_tokens = len(tokens)

-        # Trace back from final state
-        t = num_frames
-        j = num_tokens
-        path = []  # Will store (frame, token_index) pairs
+        if num_tokens == 0:
+            return []

-        while t > 0 and j >= 0:
-            # At position (t, j), we need to determine if we got here by:
-            # 1. Staying at j (emitting blank at frame t-1)
-            # 2. Moving from j-1 to j (emitting token j-1 at frame t-1)
+        # Find the best ending point (should be at num_tokens)
+        # But verify trellis reached a valid state
+        if trellis[num_frames, num_tokens] == -float("inf"):
+            # Alignment failed - fall back to uniform distribution
+            frames_per_token = num_frames / num_tokens
+            return [
+                (tokens[i], i * frames_per_token, (i + 1) * frames_per_token)
+                for i in range(num_tokens)
+            ]

-            if j == 0:
-                # Can only stay (no previous token state to come from)
-                t -= 1
-                continue
+        # Backtrack: find where each token transition occurred
+        # path[i] = frame where token i was first emitted
+        token_frames: list[list[int]] = [[] for _ in range(num_tokens)]
+
+        t = num_frames
+        j = num_tokens

-            # Compare which transition was more likely
+        while t > 0 and j > 0:
+            # Check: did we transition from j-1 to j at frame t-1?
             stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
             move_score = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

-            if move_score > stay_score:
+            if move_score >= stay_score:
                 # Token j-1 was emitted at frame t-1
-                path.append((t - 1, j - 1))
+                token_frames[j - 1].insert(0, t - 1)
                 j -= 1
-
+            # Always decrement time (monotonic)
             t -= 1

-        path.reverse()
-
-        # Convert path to token spans with start/end frames
-        if not path:
-            return []
+        # Handle any remaining tokens at the start (edge case)
+        while j > 0:
+            token_frames[j - 1].insert(0, 0)
+            j -= 1

+        # Convert to spans with sub-frame interpolation
         token_spans = []
-        i = 0
-        while i < len(path):
-            frame, token_idx = path[i]
-            start_frame = frame
-
-            # Find end frame (where this token stops being emitted)
-            end_frame = frame + 1
-            while i + 1 < len(path) and path[i + 1][1] == token_idx:
-                i += 1
-                end_frame = path[i][0] + 1
+        for token_idx, frames in enumerate(token_frames):
+            if not frames:
+                # Token never emitted - assign minimal span after previous
+                if token_spans:
+                    prev_end = token_spans[-1][2]
+                    frames = [int(prev_end)]
+                else:
+                    frames = [0]
+
+            token_id = tokens[token_idx]
+            frame_probs = emission[frames, token_id]
+            peak_idx = int(torch.argmax(frame_probs).item())
+            peak_frame = frames[peak_idx]
+
+            # Sub-frame interpolation using quadratic fit around peak
+            if len(frames) >= 3 and 0 < peak_idx < len(frames) - 1:
+                y0 = frame_probs[peak_idx - 1].item()
+                y1 = frame_probs[peak_idx].item()
+                y2 = frame_probs[peak_idx + 1].item()
+
+                denom = y0 - 2 * y1 + y2
+                if abs(denom) > 1e-10:
+                    offset = 0.5 * (y0 - y2) / denom
+                    offset = max(-0.5, min(0.5, offset))
+                else:
+                    offset = 0.0
+                refined_frame = peak_frame + offset
+            else:
+                refined_frame = float(peak_frame)

-            token_spans.append((tokens[token_idx], start_frame, end_frame))
-            i += 1
+            token_spans.append((token_id, refined_frame, refined_frame + 1.0))

         return token_spans

-    # Sub-frame offset to compensate for Wav2Vec2 convolutional look-ahead (in seconds)
-    # This makes timestamps feel more "natural" by shifting them earlier
-    OFFSET_COMPENSATION = 0.02  # 40ms
+    # Offset compensation for Wav2Vec2-BASE systematic bias (in seconds)
+    # Calibrated on librispeech-alignments dataset
+    START_OFFSET = 0.06  # Subtract from start times (shift earlier)
+    END_OFFSET = -0.03  # Add to end times (shift later)

     @classmethod
     def align(
@@ -162,7 +192,6 @@ class ForcedAligner:
         sample_rate: int = 16000,
         _language: str = "eng",
         _batch_size: int = 16,
-        offset_compensation: float | None = None,
     ) -> list[dict]:
         """Align transcript to audio and return word-level timestamps.

@@ -174,9 +203,6 @@ class ForcedAligner:
            sample_rate: Audio sample rate (default 16000)
            _language: ISO-639-3 language code (default "eng" for English, unused)
            _batch_size: Batch size for alignment model (unused)
-           offset_compensation: Time offset in seconds to subtract from timestamps
-               to compensate for Wav2Vec2 look-ahead (default: 0.04s / 40ms).
-               Set to 0 to disable.

        Returns:
            List of dicts with 'word', 'start', 'end' keys
@@ -232,8 +258,9 @@ class ForcedAligner:
         # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
         frame_duration = 320 / cls._bundle.sample_rate

-        # Apply offset compensation for Wav2Vec2 look-ahead
-        offset = offset_compensation if offset_compensation is not None else cls.OFFSET_COMPENSATION
+        # Apply separate offset compensation for start/end (Wav2Vec2 systematic bias)
+        start_offset = cls.START_OFFSET
+        end_offset = cls.END_OFFSET

         # Group aligned tokens into words based on pipe separator
         words = text.split()
@@ -246,8 +273,8 @@ class ForcedAligner:
         for token_id, start_frame, end_frame in alignment_path:
             if token_id == separator_id:  # Word separator
                 if current_word_start is not None and word_idx < len(words):
-                    start_time = max(0.0, current_word_start * frame_duration - offset)
-                    end_time = max(0.0, current_word_end * frame_duration - offset)
+                    start_time = max(0.0, current_word_start * frame_duration - start_offset)
+                    end_time = max(0.0, current_word_end * frame_duration - end_offset)
                     word_timestamps.append(
                         {
                             "word": words[word_idx],
@@ -265,8 +292,8 @@ class ForcedAligner:

         # Don't forget the last word
         if current_word_start is not None and word_idx < len(words):
-            start_time = max(0.0, current_word_start * frame_duration - offset)
-            end_time = max(0.0, current_word_end * frame_duration - offset)
+            start_time = max(0.0, current_word_start * frame_duration - start_offset)
+            end_time = max(0.0, current_word_end * frame_duration - end_offset)
             word_timestamps.append(
                 {
                     "word": words[word_idx],