farzadab committed on
Commit
39585b0
1 Parent(s): 9ef25f3

Upload model

Files changed (3)
  1. config.json +4 -0
  2. ultravox_config.py +141 -0
  3. ultravox_model.py +404 -0
config.json CHANGED
@@ -15,6 +15,10 @@
   },
   "audio_model_id": "facebook/wav2vec2-base-960h",
   "audio_token_index": 32000,
+  "auto_map": {
+    "AutoConfig": "ultravox_config.UltravoxConfig",
+    "AutoModel": "ultravox_model.UltravoxModel"
+  },
   "custom_pipelines": {
     "ultravox-pipeline": {
       "default": {
ultravox_config.py ADDED
@@ -0,0 +1,141 @@
import dataclasses
from typing import Any, Dict, List, Optional

import transformers


@dataclasses.dataclass
class LoraConfigSimplified:
    """
    Low-Rank Adaptation (LoRA) configuration.

    Used for language and audio models separately.
    """

    # The rank of the approximation
    r: int = 0
    lora_alpha: float = 8
    target_modules: Optional[List[str]] = dataclasses.field(
        default_factory=lambda: ["k_proj", "q_proj", "linear_k", "linear_q"]
    )


class UltravoxConfig(transformers.PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an [`UltravoxModel`]. It is used to instantiate an
    Ultravox model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        audio_config (`Wav2Vec2Config`, *optional*):
            Custom audio config or dict.
        text_config (`Union[AutoConfig, dict]`, *optional*):
            The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
        ignore_index (`int`, *optional*, defaults to -100):
            The ignore index for the loss function.
        audio_token_index (`int`, *optional*, defaults to 32000):
            The audio token index to encode the audio prompt.
        stack_factor (`int`, *optional*, defaults to 8):
            Audio downsampling factor for the multimodal projector.
        norm_init (`float`, *optional*, defaults to 0.4):
            The initialization value for the layer normalization.
        projector_act (`str`, *optional*, defaults to `"swiglu"`):
            The activation function used by the multimodal projector.
        text_model_lora_config (`LoraConfigSimplified`, *optional*):
            The LoRA configuration for finetuning the text model.
        audio_model_lora_config (`LoraConfigSimplified`, *optional*):
            The LoRA configuration for finetuning the audio model.

    Example:

    ```python
    >>> from transformers import UltravoxModel, Wav2Vec2Config, UltravoxConfig, LlamaConfig

    >>> # Initializing an audio encoder config
    >>> audio_config = Wav2Vec2Config()

    >>> # Initializing a Llama config
    >>> text_config = LlamaConfig()

    >>> # Initializing a default configuration
    >>> configuration = UltravoxConfig(audio_config.to_dict(), text_config.to_dict())

    >>> # Initializing a completely untrained model from the configuration
    >>> model = UltravoxModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config

    >>> # Initialize a model from pretrained checkpoints and random projector weights
    >>> config = UltravoxConfig(audio_model_id="facebook/wav2vec2-base-960h", text_model_id="meta-llama/Llama-2-7b-chat-hf")
    ```"""

    model_type = "ultravox"
    is_composition = False

    def __init__(
        self,
        audio_config: Optional[Dict[str, Any]] = None,
        text_config: Optional[Dict[str, Any]] = None,
        audio_model_id: Optional[str] = None,
        text_model_id: Optional[str] = None,
        ignore_index: int = -100,
        audio_token_index: int = 32000,
        hidden_size: int = 4096,
        stack_factor: int = 8,
        norm_init: float = 0.4,
        projector_act: str = "swiglu",
        text_model_lora_config: Optional[LoraConfigSimplified] = None,
        audio_model_lora_config: Optional[LoraConfigSimplified] = None,
        **kwargs,
    ):
        self.ignore_index = ignore_index

        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id
        self.audio_token_index = audio_token_index

        self.hidden_size = hidden_size
        self.stack_factor = stack_factor
        self.norm_init = norm_init
        self.projector_act = projector_act

        if text_model_id is not None:
            self.text_config: transformers.LlamaConfig = (
                transformers.AutoConfig.from_pretrained(text_model_id)
            )
        else:
            text_config = text_config or {}
            self.text_config = transformers.CONFIG_MAPPING[
                text_config.get("model_type", "llama")
            ](**text_config)

        if audio_model_id is not None:
            self.audio_config: transformers.PretrainedConfig = (
                transformers.AutoConfig.from_pretrained(audio_model_id)
            )
        else:
            audio_config = audio_config or {}
            self.audio_config = transformers.CONFIG_MAPPING[
                audio_config.get("model_type", "wav2vec2")
            ](**audio_config)

        self.text_model_lora_config = (
            text_model_lora_config
            if isinstance(text_model_lora_config, dict)
            else dataclasses.asdict(text_model_lora_config or LoraConfigSimplified())
        )
        self.audio_model_lora_config = (
            audio_model_lora_config
            if isinstance(audio_model_lora_config, dict)
            else dataclasses.asdict(audio_model_lora_config or LoraConfigSimplified())
        )

        self.vocab_size = self.text_config.vocab_size

        self.initializer_range = self.text_config.initializer_range

        super().__init__(**kwargs)
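
The constructor above accepts LoRA settings either as `LoraConfigSimplified` instances or plain dicts and normalizes them to dicts via `dataclasses.asdict`, so they serialize cleanly into config.json. A short sketch, assuming `ultravox_config.py` is importable and using illustrative LoRA values (the model ids are the ones from the docstring example):

from ultravox_config import LoraConfigSimplified, UltravoxConfig

config = UltravoxConfig(
    audio_model_id="facebook/wav2vec2-base-960h",
    text_model_id="meta-llama/Llama-2-7b-chat-hf",
    # Illustrative: rank-16 LoRA on the text model; the audio tower keeps the default r=0 (frozen).
    text_model_lora_config=LoraConfigSimplified(r=16, lora_alpha=32),
)

# Both LoRA configs end up stored as plain dicts on the config object.
assert isinstance(config.text_model_lora_config, dict)
assert config.audio_model_lora_config["r"] == 0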
ultravox_model.py ADDED
@@ -0,0 +1,404 @@
import logging
from typing import Any, Dict, Optional, Set, Tuple, Union

import peft
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import transformers.activations
import transformers.modeling_outputs
import transformers.models

# We must use relative import in this directory to allow uploading to HF Hub
from . import ultravox_config
from . import whisper_model_modified


class UltravoxModel(
    transformers.LlamaPreTrainedModel,
    transformers.GenerationMixin,
):
    """
    The Ultravox model, which consists of an audio encoder and a language model.

    Audio input is processed by the audio encoder, then every `stack_factor` frames are stacked together and
    projected to the language model's embedding space using a few linear layers.
    The text is embedded by the language model as usual, and then the audio and text embeddings are merged together.

    A special token `<|audio|>` is used to indicate the start of the audio embeddings in the merged embeddings.

    Parameters:
        config: Model configuration class with all the parameters of the model.
    """

    config_class = ultravox_config.UltravoxConfig
    config: ultravox_config.UltravoxConfig  # for type hinting
    _no_split_modules = ["Wav2Vec2Model", "WhisperEncoder", "LlamaDecoderLayer"]

    def __init__(self, config: ultravox_config.UltravoxConfig):
        super().__init__(config)

        self.keep_params: Set[str] = set()
        self.vocab_size = config.vocab_size

        self.audio_tower = self._create_audio_tower(config)
        self.multi_modal_projector = UltravoxProjector(config)
        self.language_model = self._create_language_model(config)

        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()

    def _setup_cache(
        self, cache_cls, max_batch_size: int, max_cache_len: Optional[int] = None
    ):
        self.language_model._setup_cache(cache_cls, max_batch_size, max_cache_len)

    def _reorder_cache(self, past_key_values, beam_idx):
        return self.language_model._reorder_cache(past_key_values, beam_idx)

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(
            new_num_tokens, pad_to_multiple_of
        )
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        audio_values: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        audio_token_start_idx: Optional[torch.Tensor] = None,
        audio_token_len: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple] = None,
        **kwargs,
    ) -> Union[Tuple, transformers.modeling_outputs.CausalLMOutputWithPast]:
        """
        Forward pass for the Ultravox model.

        `input_ids` are the tokenized text input. They are embedded by the language model as usual.
        `audio_values` are processed by the audio encoder, then every `stack_factor` frames are stacked together and
        projected to the language model's embedding space using a few linear layers.
        The audio and text embeddings are merged together. A special token `<|audio|>` is used to indicate the start
        of the audio embeddings in the merged embeddings.

        Args:
            input_ids: The tokenized text input.
            audio_values: The processed audio values.
            inputs_embeds: The embeddings for the input tokens.
            labels: The tokenized text labels.
            attention_mask: The attention mask for the input.
            audio_token_start_idx: The index where the audio embeddings start for each sample in the batch.
            audio_token_len: The number of audio tokens for each sample in the batch.
            past_key_values: The past key value cache for the language model attention layers.
            **kwargs: Additional keyword arguments. Passed directly to the language model.
        """
        if inputs_embeds is None:
            # B x T -> B x T x D
            inputs_embeds = self.get_input_embeddings().forward(input_ids)

        if audio_values is not None:
            assert (
                audio_token_start_idx is not None and audio_token_len is not None
            ), "audio_token_start_idx and audio_token_len must be provided if audio_values are provided."
            assert (
                len(audio_token_start_idx) == len(audio_token_len) == len(audio_values)
            ), "audio_token_start_idx, audio_token_len, and audio_values must have the same batch size."

            # B x A/3200 x D
            audio_tower_output = self.audio_tower.forward(
                audio_values
            ).last_hidden_state
            audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)

            audio_embeds = self.multi_modal_projector.forward(audio_tower_output)

            # combine audio and text embeddings
            for i, (audio, start, length) in enumerate(
                zip(audio_embeds, audio_token_start_idx, audio_token_len)
            ):
                length = min(length, audio.shape[0])
                inputs_embeds[i, start : start + length] = audio[:length]

        lm_output = self.language_model.forward(
            inputs_embeds=inputs_embeds,
            labels=labels,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            **kwargs,
        )

        return lm_output

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        audio_values: Optional[torch.FloatTensor] = None,
        audio_token_start_idx: Optional[torch.Tensor] = None,
        audio_token_len: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        model_input = self.language_model.prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        if past_key_values is None and audio_values is not None:
            # We only want to use audio features in the 1st generation step
            model_input["audio_values"] = audio_values
            model_input["audio_token_start_idx"] = audio_token_start_idx
            model_input["audio_token_len"] = audio_token_len

        return model_input

    @classmethod
    def _create_audio_tower(cls, config: ultravox_config.UltravoxConfig) -> Union[
        transformers.Wav2Vec2Model,
        transformers.models.whisper.modeling_whisper.WhisperEncoder,
    ]:
        if config.audio_model_id is not None:
            if "whisper" in config.audio_model_id:
                audio_tower = whisper_model_modified.WhisperEncoder.from_pretrained(
                    config.audio_model_id
                )
            else:
                audio_tower = transformers.AutoModel.from_pretrained(
                    config.audio_model_id
                )
        else:
            if "whisper" in config.audio_config._name_or_path:
                audio_tower = whisper_model_modified.WhisperEncoder(config.audio_config)
            else:
                audio_tower = transformers.AutoModel.from_config(config.audio_config)

        if isinstance(
            audio_tower,
            (transformers.Wav2Vec2BertModel, transformers.WhisperModel),
        ):
            # For these models we only need the encoder part
            # Wav2Vec2BertModel -> Wav2Vec2BertEncoder
            # WhisperModel -> WhisperEncoder
            audio_tower = audio_tower.encoder

        audio_tower = apply_lora(audio_tower, config.audio_model_lora_config)
        return audio_tower

    @classmethod
    def _create_language_model(
        cls, config: ultravox_config.UltravoxConfig
    ) -> transformers.LlamaForCausalLM:
        if config.text_model_id is not None:
            language_model = transformers.AutoModelForCausalLM.from_pretrained(
                config.text_model_id, attn_implementation=config._attn_implementation
            )
        else:
            language_model = transformers.AutoModelForCausalLM.from_config(
                config.text_config, attn_implementation=config._attn_implementation
            )

        language_model = apply_lora(language_model, config.text_model_lora_config)
        return language_model

    def merge_and_unload(self):
        if isinstance(self.language_model, peft.PeftModel):
            self.language_model = self.language_model.merge_and_unload()
            # no need to download base language model weights anymore, so we can remove the id
            self.config.text_model_id = None
            self.keep_params.update(
                set(
                    [
                        f"language_model.{name}"
                        for name, _ in self.language_model.named_parameters()
                    ]
                )
            )

        if isinstance(self.audio_tower, peft.PeftModel):
            self.audio_tower = self.audio_tower.merge_and_unload()
            # no need to download base audio model weights anymore, so we can remove the id
            self.config.audio_model_id = None
            self.keep_params.update(
                set(
                    [
                        f"audio_tower.{name}"
                        for name, _ in self.audio_tower.named_parameters()
                    ]
                )
            )

        for param in ["text_model_lora_config", "audio_model_lora_config"]:
            if hasattr(self.config, param):
                delattr(self.config, param)

    def push_to_hub(self, *args, **kwargs):
        self.merge_and_unload()
        self.to(self.language_model.dtype)
        return super().push_to_hub(*args, **kwargs)

    def state_dict(self, *args, **kwargs):
        named_params = dict(self.named_parameters())
        state_dict = super().state_dict(*args, **kwargs)

        state_dict = {
            k: v
            for k, v in state_dict.items()
            if k in self.keep_params
            or (k in named_params and named_params[k].requires_grad)
        }
        return state_dict

    def load_state_dict(
        self,
        state_dict: Dict[str, Any],
        *args,
        **kwargs,
    ):
        self.keep_params.update(set(state_dict.keys()))
        return super().load_state_dict(state_dict, *args, **kwargs)

    def print_trainable_parameters(self):
        """
        Prints the number of trainable parameters in the model (reuses the PEFT model's method).
        """
        count_params = peft.peft_model.PeftModel.get_nb_trainable_parameters

        trainable_params, all_param = count_params(self)

        logging.info(
            f"trainable params: {trainable_params:,d} || all params: {all_param:,d}"
            f" || trainable%: {100 * trainable_params / all_param:.1f}%"
        )

        lm_trainable_params, lm_all_params = count_params(self.language_model)
        audio_trainable_params, audio_all_params = count_params(self.audio_tower)

        projector_trainable_params = (
            trainable_params - lm_trainable_params - audio_trainable_params
        )
        projector_all_params = all_param - lm_all_params - audio_all_params

        logging.info(
            f"Trainable%: "
            f" LLM: {100 * lm_trainable_params / lm_all_params:.1f}%"
            f" || Audio Encoder: {100 * audio_trainable_params / audio_all_params:.1f}%"
            f" || Projector: {100 * projector_trainable_params / projector_all_params:.1f}%"
        )


def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
    """
    Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.
    """
    lora_config = peft.LoraConfig(**lora_config or {})

    if lora_config.r == 0:
        # freeze the model entirely
        for param in model.parameters():
            param.requires_grad = False
    else:
        model = peft.get_peft_model(model, lora_config)

    return model


class StackAudioFrames(nn.Module):
    """
    Stack the audio embedding frames to reduce the sequence length by a factor of `stack_factor`.

    The number of output frames will be `ceil(T / stack_factor) + 1` where `T` is the number of input frames.
    NOTE: the extra +1 is intentional: in case the number of audio tokens is over-estimated by the processor,
    we want to make sure `processor.audio_token_replacement` (i.e. EOS) doesn't get leaked into the middle of embeddings.
    In most cases this extra padding will get removed in the model's forward function, so it has no effect.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        B, T, C = audio_embeds.shape
        T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T + self.stack_factor))
        B, T, C = audio_embeds.shape
        audio_embeds = audio_embeds.view(
            B, T // self.stack_factor, C * self.stack_factor
        )
        return audio_embeds


class RMSNorm(transformers.models.llama.modeling_llama.LlamaRMSNorm):
    def __init__(self, hidden_size: int, init: float = 1, eps: float = 1e-6):
        super().__init__(hidden_size=hidden_size, eps=eps)
        self.weight.data.fill_(init)


class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x


class UltravoxProjector(nn.Sequential):
    def __init__(self, config: ultravox_config.UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        dim = config.audio_config.hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(dim, init=config.norm_init)
        self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
        dim = self.hidden_dim
        self.act = transformers.activations.get_activation(config.projector_act)
        dim = dim // 2 if config.projector_act == "swiglu" else dim
        self.linear_2 = nn.Linear(dim, config.text_config.hidden_size, bias=False)
        self.ln_post = RMSNorm(config.text_config.hidden_size, init=config.norm_init)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        audio_features = self._pad_and_stack(audio_features)
        audio_features = self.ln_pre(audio_features)
        hidden_states = self.linear_1(audio_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.ln_post(hidden_states)
        return hidden_states


transformers.AutoModelForCausalLM.register(
    ultravox_config.UltravoxConfig, UltravoxModel
)

transformers.activations.ACT2FN["swiglu"] = SwiGLU