Update ultravox_processing.py
Browse files- ultravox_processing.py +28 -0
ultravox_processing.py
CHANGED
@@ -4,6 +4,8 @@ import numpy as np
|
|
4 |
import torch
|
5 |
import transformers
|
6 |
|
|
|
|
|
7 |
|
8 |
class UltravoxProcessor(transformers.ProcessorMixin):
|
9 |
"""
|
@@ -59,6 +61,29 @@ class UltravoxProcessor(transformers.ProcessorMixin):
|
|
59 |
|
60 |
super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
def __call__(
|
63 |
self,
|
64 |
text: Optional[str] = None,
|
@@ -178,3 +203,6 @@ class UltravoxProcessor(transformers.ProcessorMixin):
|
|
178 |
tokenizer_input_names = self.tokenizer.model_input_names
|
179 |
audio_processor_input_names = self.audio_processor.model_input_names
|
180 |
return list(set(tokenizer_input_names + audio_processor_input_names))
|
|
|
|
|
|
|
|
4 |
import torch
|
5 |
import transformers
|
6 |
|
7 |
+
from .ultravox_config import UltravoxConfig
|
8 |
+
|
9 |
|
10 |
class UltravoxProcessor(transformers.ProcessorMixin):
|
11 |
"""
|
|
|
61 |
|
62 |
super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)
|
63 |
|
64 |
+
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
    """Build a processor from a pretrained Ultravox checkpoint.

    The model config is loaded first and used to decide which audio
    processor to load: ``config.audio_model_id`` if set, otherwise
    ``config.audio_config._name_or_path``, otherwise the
    ``"facebook/wav2vec2-base-960h"`` default. The tokenizer is loaded
    from the same location as the config, switched to left padding, and
    given its EOS token as the pad token.

    Args:
        pretrained_model_name_or_path: Passed to
            ``AutoConfig.from_pretrained`` and
            ``AutoTokenizer.from_pretrained``.
        **kwargs: Forwarded to those two calls only; the audio processor
            is loaded without extra keyword arguments.

    Returns:
        An instance of ``cls`` built from the loaded audio processor,
        the tokenizer, and ``config.stack_factor``.
    """
    ultravox_config: UltravoxConfig = transformers.AutoConfig.from_pretrained(
        pretrained_model_name_or_path, **kwargs
    )

    # Resolve the audio model lazily: later fallbacks are only evaluated
    # when the earlier candidates are falsy.
    audio_source = (
        ultravox_config.audio_model_id
        or ultravox_config.audio_config._name_or_path
        or "facebook/wav2vec2-base-960h"
    )
    audio_proc = transformers.AutoProcessor.from_pretrained(audio_source)

    text_tokenizer = transformers.AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path, **kwargs
    )
    # Pad on the left and reuse EOS as the padding token.
    text_tokenizer.padding_side = "left"
    text_tokenizer.pad_token = text_tokenizer.eos_token

    return cls(
        audio_processor=audio_proc,
        tokenizer=text_tokenizer,
        stack_factor=ultravox_config.stack_factor,
    )
|
86 |
+
|
87 |
def __call__(
|
88 |
self,
|
89 |
text: Optional[str] = None,
|
|
|
203 |
tokenizer_input_names = self.tokenizer.model_input_names
|
204 |
audio_processor_input_names = self.audio_processor.model_input_names
|
205 |
return list(set(tokenizer_input_names + audio_processor_input_names))
|
206 |
+
|
207 |
+
|
208 |
+
# Register UltravoxProcessor with transformers.AutoProcessor for UltravoxConfig
# (presumably so AutoProcessor can resolve this processor from an Ultravox
# config — confirm against the transformers auto-class docs). This runs as a
# module-level side effect at import time.
transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)
|