mazesmazes committed · verified
Commit 12ada82 · 1 Parent(s): 85a64d2

Training in progress - step 15000

Files changed (2):
  1. asr_modeling.py +2 -160
  2. asr_pipeline.py +3 -116
asr_modeling.py CHANGED
@@ -1,8 +1,5 @@
 from pathlib import Path
-from typing import Optional, Union, Generator, NamedTuple
-
-import threading
-from concurrent import futures
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -14,7 +11,6 @@ from transformers import (
     AutoTokenizer,
     PreTrainedModel,
     Wav2Vec2FeatureExtractor,
-    TextIteratorStreamer,
 )
 from transformers.generation.utils import (
     GenerateBeamDecoderOnlyOutput,
@@ -29,17 +25,6 @@ except ImportError:
     from asr_config import ASRConfig  # type: ignore[no-redef]
 
 
-class StreamChunk(NamedTuple):
-    """A chunk of streaming transcription text."""
-    text: str
-
-
-class StreamStats(NamedTuple):
-    """Statistics about the streaming inference."""
-    input_tokens: int
-    output_tokens: int
-
-
 class SwiGLU(nn.Module):
     def __init__(self, in_features, hidden_features, out_features, bias=False, dropout=0.0):
         super().__init__()
@@ -133,12 +118,8 @@ class ASRModel(PreTrainedModel):
         return WhisperFeatureExtractor.from_pretrained(
             audio_model_id,
             feature_size=num_mel_bins,
-            do_normalize=True,
         )
+        return Wav2Vec2FeatureExtractor.from_pretrained(audio_model_id)
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
@@ -706,145 +687,6 @@
 
         return generated_ids[:, prompt_length:]
 
-    @torch.no_grad()
-    def generate_stream(
-        self,
-        input_values: Optional[torch.Tensor] = None,
-        input_features: Optional[torch.Tensor] = None,
-        system_prompt: Optional[str] = None,
-        user_prompt: Optional[str] = None,
-        task: Optional[str] = None,
-        max_new_tokens: Optional[int] = None,
-        temperature: Optional[float] = None,
-        **generate_kwargs,
-    ) -> Generator[Union[StreamChunk, StreamStats], None, None]:
-        """
-        Generate transcription in streaming mode, yielding text chunks as they're generated.
-
-        Args:
-            input_values: Audio input tensor for non-Whisper models
-            input_features: Audio input tensor for Whisper models
-            system_prompt: System prompt override
-            user_prompt: User prompt override
-            task: Task type (transcribe, describe, emotion, continue)
-            max_new_tokens: Maximum tokens to generate
-            temperature: Sampling temperature
-            **generate_kwargs: Additional generation parameters
-
-        Yields:
-            StreamChunk: Text chunks as they're generated
-            StreamStats: Final statistics (input_tokens, output_tokens)
-        """
-        audio_inputs = input_values if input_values is not None else input_features
-        if audio_inputs is None:
-            raise ValueError("input_values or input_features must be provided for generation")
-
-        # Encode audio once and prepare prompt
-        audio_embeds = self._encode_audio(audio_inputs)
-        batch_size = audio_embeds.shape[0]
-        device = audio_embeds.device
-
-        if batch_size > 1:
-            raise ValueError("Streaming generation only supports batch_size=1")
-
-        if system_prompt is None:
-            system_prompt = self.system_prompt
-
-        if user_prompt is None:
-            user_prompt = (
-                self.TASK_PROMPTS.get(task, self.config.user_prompt or "Transcribe: <audio>")
-                or "Transcribe: <audio>"
-            )
-
-        messages = []
-        if system_prompt:
-            messages.append({"role": "system", "content": system_prompt})
-        messages.append({"role": "user", "content": user_prompt})
-
-        prompt_ids = self.tokenizer.apply_chat_template(
-            messages,
-            tokenize=True,
-            add_generation_prompt=True,
-            return_tensors="pt",
-            enable_thinking=False,
-        ).to(device)
-
-        if len(prompt_ids.shape) == 1:
-            prompt_ids = prompt_ids.unsqueeze(0)
-
-        if not (prompt_ids == self.audio_token_id).any():
-            raise ValueError("Audio token <audio> not found in prompt")
-
-        num_audio_tokens = audio_embeds.shape[1]
-        expanded_prompt_ids = self._expand_audio_tokens(prompt_ids, num_audio_tokens)
-        inputs_embeds = self._prepare_audio_inputs_embeds(expanded_prompt_ids, audio_embeds)
-        input_token_count = expanded_prompt_ids.shape[1]
-
-        attention_mask = torch.ones(
-            batch_size, input_token_count, dtype=torch.long, device=device
-        )
-
-        # Set up generation parameters
-        if max_new_tokens is None:
-            max_new_tokens = getattr(self.config, "max_new_tokens", 256)
-
-        generate_kwargs.setdefault("max_new_tokens", max_new_tokens)
-        generate_kwargs.setdefault("use_cache", True)
-        generate_kwargs.setdefault(
-            "eos_token_id", self.tokenizer.convert_tokens_to_ids("<|im_end|>")
-        )
-        generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
-
-        if temperature is not None:
-            generate_kwargs["temperature"] = temperature
-            generate_kwargs.setdefault("do_sample", True)
-
-        # Set up the streamer
-        streamer = TextIteratorStreamer(
-            self.tokenizer,
-            skip_prompt=True,
-            skip_special_tokens=True
-        )
-
-        # Generate in a separate thread
-        def generation_thread(future: futures.Future):
-            try:
-                result = self.decoder.generate(
-                    input_ids=expanded_prompt_ids,
-                    inputs_embeds=inputs_embeds,
-                    attention_mask=attention_mask,
-                    streamer=streamer,
-                    **generate_kwargs,
-                )
-                future.set_result(result)
-            except Exception as e:
-                future.set_exception(e)
-
-        future: futures.Future[torch.Tensor] = futures.Future()
-        thread = threading.Thread(target=generation_thread, args=(future,))
-        thread.start()
-
-        # Stream the output
-        output_text = ""
-        output_token_count = 0
-
-        try:
-            for chunk in streamer:
-                if chunk:
-                    output_text += chunk
-                    output_token_count += 1
-                    yield StreamChunk(chunk)
-        finally:
-            # Wait for generation to complete
-            thread.join()
-
-        # Check if there was an exception
-        if future.exception():
-            raise future.exception()
-
-        # Yield final statistics
-        yield StreamStats(input_token_count, output_token_count)
-
     def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
         import shutil
         from pathlib import Path as PathlibPath
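
For context, the removed generate_stream was a generator API: it yielded StreamChunk pieces of text as tokens decoded, then one final StreamStats tuple. A minimal sketch of the call pattern it supported, reconstructed from the deleted code above; the checkpoint path, the placeholder waveform, and the whisper-tiny feature extractor are illustrative assumptions, not part of this repo:

import numpy as np
from transformers import WhisperFeatureExtractor

from asr_modeling import ASRModel, StreamChunk, StreamStats

model = ASRModel.from_pretrained("./checkpoint")  # hypothetical local path
extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")  # assumed to match the model

waveform = np.zeros(16000, dtype=np.float32)  # placeholder: 1 s of silence at 16 kHz
features = extractor(waveform, sampling_rate=16000, return_tensors="pt").input_features

pieces = []
for item in model.generate_stream(input_features=features, task="transcribe"):
    if isinstance(item, StreamChunk):
        pieces.append(item.text)  # incremental text, yielded as tokens decode
    elif isinstance(item, StreamStats):
        print(f"input_tokens={item.input_tokens} output_tokens={item.output_tokens}")
print("".join(pieces))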
asr_pipeline.py CHANGED
@@ -1,13 +1,13 @@
-from typing import Any, Dict, Generator, Union
+from typing import Any, Dict
 
 import torch
 import transformers
 from truecase import get_true_case
 
 try:
-    from .asr_modeling import ASRModel, StreamChunk, StreamStats
+    from .asr_modeling import ASRModel
 except ImportError:
-    from asr_modeling import ASRModel, StreamChunk, StreamStats  # type: ignore[no-redef]
+    from asr_modeling import ASRModel  # type: ignore[no-redef]
 
 
 class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
@@ -31,11 +31,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         self.text_normalizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
 
     def __call__(self, inputs, **kwargs):
-        # Check if streaming is requested
-        stream = kwargs.pop("stream", False)
-        if stream:
-            return self._stream_inference(inputs, **kwargs)
-
         generate_kwargs = {}
         for key in [
             "max_new_tokens",
@@ -297,111 +292,3 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         text = get_true_case(text)
 
         return {"text": text}
-
-    def _stream_inference(
-        self, inputs, **kwargs
-    ) -> Generator[Union[Dict[str, str], Dict[str, int]], None, None]:
-        """
-        Perform streaming inference on audio input.
-
-        Args:
-            inputs: Audio input (same format as __call__)
-            **kwargs: Generation parameters
-
-        Yields:
-            Dict with "text" key containing text chunks as they're generated,
-            followed by a final dict with "input_tokens" and "output_tokens" statistics
-        """
-        # Extract generation kwargs
-        generate_kwargs = {}
-        for key in [
-            "max_new_tokens",
-            "temperature",
-            "do_sample",
-            "top_k",
-            "top_p",
-            "user_prompt",
-            "task",
-            "system_prompt",
-        ]:
-            if key in kwargs:
-                generate_kwargs[key] = kwargs.pop(key)
-
-        # Disable chunking for streaming - we want the whole audio at once
-        kwargs.pop("chunk_length_s", None)
-        kwargs.pop("stride_length_s", None)
-
-        # Preprocess audio to get model inputs
-        model_inputs = self.preprocess(inputs, chunk_length_s=0, **kwargs)
-
-        # Handle different input formats
-        audio_inputs = None
-        is_whisper = False
-
-        # Check if preprocess returned an iterator (shouldn't with chunk_length_s=0)
-        from collections.abc import Iterator
-        if isinstance(model_inputs, Iterator):
-            # Get the first (and should be only) chunk
-            try:
-                model_inputs = next(model_inputs)
-            except StopIteration:
-                raise ValueError("Preprocess returned empty iterator")
-
-        if isinstance(model_inputs, torch.Tensor):
-            audio_inputs = model_inputs
-        elif isinstance(model_inputs, dict):
-            # Remove metadata fields
-            model_inputs.pop("is_last", None)
-            model_inputs.pop("stride", None)
-
-            # Get audio input (Whisper uses input_features, others use input_values)
-            if "input_features" in model_inputs:
-                audio_inputs = model_inputs["input_features"]
-                is_whisper = True
-            else:
-                audio_inputs = model_inputs.get("input_values")
-
-        if audio_inputs is None:
-            # Debug info
-            import sys
-            print(f"DEBUG: model_inputs type: {type(model_inputs)}", file=sys.stderr)
-            if isinstance(model_inputs, dict):
-                print(f"DEBUG: model_inputs keys: {model_inputs.keys()}", file=sys.stderr)
-            raise ValueError(f"Could not extract audio inputs from preprocessing. Got type: {type(model_inputs)}")
-
-        if isinstance(audio_inputs, torch.Tensor):
-            audio_inputs = audio_inputs.to(self.model.device)
-        else:
-            raise ValueError(f"audio inputs must be a tensor, got {type(audio_inputs)}")
-
-        # Call the streaming generate method
-        if is_whisper:
-            stream_generator = self.model.generate_stream(
-                input_features=audio_inputs,
-                **generate_kwargs,
-            )
-        else:
-            stream_generator = self.model.generate_stream(
-                input_values=audio_inputs,
-                **generate_kwargs,
-            )
-
-        # Track full text for post-processing
-        full_text = ""
-
-        # Stream the chunks
-        for item in stream_generator:
-            if isinstance(item, StreamChunk):
-                full_text += item.text
-                yield {"text": item.text}
-            elif isinstance(item, StreamStats):
-                # Apply post-processing to the full text
-                processed_text = self.text_normalizer.normalize(full_text)
-                processed_text = get_true_case(processed_text)
-
-                # Yield final statistics with processed text
-                yield {
-                    "input_tokens": item.input_tokens,
-                    "output_tokens": item.output_tokens,
-                    "full_text": processed_text,
-                }