farzadab commited on
Commit
aed30da
1 Parent(s): e3e0769

Update ultravox_processing.py

Browse files
Files changed (1) hide show
  1. ultravox_processing.py +29 -1
ultravox_processing.py CHANGED
@@ -1,9 +1,11 @@
1
- from typing import Optional, Union
2
 
3
  import numpy as np
4
  import torch
5
  import transformers
6
 
 
 
7
 
8
  class UltravoxProcessor(transformers.ProcessorMixin):
9
  """
@@ -56,6 +58,29 @@ class UltravoxProcessor(transformers.ProcessorMixin):
56
  ), "The tokenizer has no EOS token. Cannot recover."
57
  super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def __call__(
60
  self,
61
  text: Optional[str] = None,
@@ -175,3 +200,6 @@ class UltravoxProcessor(transformers.ProcessorMixin):
175
  tokenizer_input_names = self.tokenizer.model_input_names
176
  audio_processor_input_names = self.audio_processor.model_input_names
177
  return list(set(tokenizer_input_names + audio_processor_input_names))
 
 
 
 
1
+ from typing import Optional, Union, Dict, Any
2
 
3
  import numpy as np
4
  import torch
5
  import transformers
6
 
7
+ from .ultravox_config import UltravoxConfig
8
+
9
 
10
  class UltravoxProcessor(transformers.ProcessorMixin):
11
  """
 
58
  ), "The tokenizer has no EOS token. Cannot recover."
59
  super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)
60
 
61
+ @classmethod
62
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
63
+ config: UltravoxConfig = transformers.AutoConfig.from_pretrained(
64
+ pretrained_model_name_or_path, **kwargs
65
+ )
66
+ audio_processor = transformers.AutoProcessor.from_pretrained(
67
+ config.audio_model_id
68
+ or config.audio_config._name_or_path
69
+ or "facebook/wav2vec2-base-960h"
70
+ )
71
+
72
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
73
+ pretrained_model_name_or_path, **kwargs
74
+ )
75
+ tokenizer.padding_side = "left"
76
+ tokenizer.pad_token = tokenizer.eos_token
77
+
78
+ return cls(
79
+ audio_processor=audio_processor,
80
+ tokenizer=tokenizer,
81
+ stack_factor=config.stack_factor,
82
+ )
83
+
84
  def __call__(
85
  self,
86
  text: Optional[str] = None,
 
200
  tokenizer_input_names = self.tokenizer.model_input_names
201
  audio_processor_input_names = self.audio_processor.model_input_names
202
  return list(set(tokenizer_input_names + audio_processor_input_names))
203
+
204
+
205
+ transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)