File size: 17,038 Bytes
128578f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 |
"""Inference-only MERaLiON AudioLLM model compatible with HuggingFace weights."""
from functools import lru_cache
from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union
import librosa
import numpy as np
import torch
import torch.nn as nn
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
from vllm.model_executor.models.utils import maybe_prefix
from .modeling_meralion import MERaLiONSpeechEncoder
logger = init_logger(__name__)

# Checkpoint parameter names nest the decoder one level deeper
# (`text_decoder.model.*`) than this module, whose `text_decoder` attribute is
# the bare Gemma2Model; `load_weights` rewrites each key through this table to
# drop that extra level.
_KEYS_TO_MODIFY_MAPPING = {"text_decoder.model": "text_decoder"}
# === Audio Inputs === #
class MERaLiONInputs(TypedDict):
    """Batched audio tensors handed to the speech encoder."""

    # Audio features for every clip in the batch.
    # Shape: `(num_audios, num_mel_bins, 3000)`
    input_features: torch.Tensor

    # Attention mask over the feature frames of each clip.
    # Shape: `(num_audios, 3000)`
    feature_attention_mask: torch.Tensor
# === Audio Encoder === #
class MERaLiONSpeechAudioAdaper(nn.Module):
    """Project speech-encoder states into the text decoder's embedding space.

    Every ``speech_mlp_scale_factor`` consecutive encoder frames are stacked
    along the channel axis (shortening the sequence by that factor) and fed
    through ``mlp_adapter``; ``speech_llm_proj`` then maps each position to
    ``text_hidden_size``.

    Args:
        audio_hidden_size: Channel size ``C`` of the encoder output.
        text_hidden_size: Hidden size of the text decoder.
        speech_mlp_scale_factor: Number of consecutive frames merged into one
            output position. Defaults to 15; the first Linear layer's input
            size depends on it, so it must match the checkpoint.
    """

    def __init__(self, audio_hidden_size: int, text_hidden_size: int,
                 speech_mlp_scale_factor: int = 15):
        super().__init__()
        self.speech_mlp_scale_factor = speech_mlp_scale_factor
        # The (Linear, SiLU, Dropout[, Linear]) layout is kept unchanged: the
        # nn.Sequential indices are part of the parameter names that
        # checkpoint weights are loaded into.
        self.mlp_adapter = nn.Sequential(
            nn.Linear(
                in_features=audio_hidden_size * speech_mlp_scale_factor,
                out_features=audio_hidden_size,
            ),
            nn.SiLU(),
            nn.Dropout(0.1),
        )
        self.speech_llm_proj = nn.Sequential(
            nn.Linear(audio_hidden_size, audio_hidden_size * 4),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(audio_hidden_size * 4, text_hidden_size),
        )

    def forward(self, speech_embeds, **kwargs):
        """Merge frames and project them to the text hidden size.

        Args:
            speech_embeds: Tensor of shape ``(B, T, C)``; ``T`` must be a
                multiple of ``speech_mlp_scale_factor``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Tensor of shape ``(B, T // speech_mlp_scale_factor,
            text_hidden_size)``.

        Raises:
            ValueError: if ``T`` is not divisible by the scale factor (the
                reshape below would otherwise fail with an opaque error).
        """
        batch, frames, channels = speech_embeds.shape
        factor = self.speech_mlp_scale_factor
        if frames % factor != 0:
            raise ValueError(
                f"Speech sequence length {frames} is not divisible by "
                f"speech_mlp_scale_factor={factor}")
        stacked = speech_embeds.reshape(batch, frames // factor,
                                        channels * factor)
        return self.speech_llm_proj(self.mlp_adapter(stacked))
def dummy_data_for_meralion(ctx: InputContext, seq_len: int,
                            mm_counts: Mapping[str, int]):
    """Build the dummy prompt and audio for the ``register_dummy_data`` hook.

    The dummy sequence starts with ``get_max_meralion_audio_tokens(ctx)``
    speech placeholder tokens per audio clip, followed by token id 0 up to
    ``seq_len``. One silent waveform (16 kHz) is supplied per clip, together
    with consecutive placeholder ranges marking where each clip's tokens sit.

    Args:
        ctx: vLLM input context (provides the HF config).
        seq_len: Total sequence length to fill.
        mm_counts: Multimodal item counts; ``mm_counts["audio"]`` is used.

    Raises:
        RuntimeError: if ``seq_len`` cannot hold all audio placeholder tokens
            plus at least two additional positions.
    """
    num_audios = mm_counts["audio"]
    max_tokens_per_audio = get_max_meralion_audio_tokens(ctx)
    max_llm_audio_tokens = max_tokens_per_audio * num_audios
    if seq_len - max_llm_audio_tokens - 2 < 0:
        raise RuntimeError(
            f"MERaLiON cannot process {num_audios} audios in a prompt, "
            "please increase max_model_len or reduce audio limit by "
            "--limit-mm-per-prompt.")

    speech_token_index = ctx.model_config.hf_config.speech_token_index
    dummy_seqdata = SequenceData.from_prompt_token_counts(
        (speech_token_index, max_llm_audio_tokens),
        (0, seq_len - max_llm_audio_tokens),
    )
    # Silent waveform; the input mapper pads extracted features to
    # max_length, so the exact sample count here is not critical.
    dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.)
    return DummyData(
        dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}, {
            "audio":
            consecutive_placeholder_ranges(num_items=num_audios,
                                           item_size=max_tokens_per_audio)
        })
def get_processor(
    processor_name: str,
    *args,
    trust_remote_code: bool = True,
    **kwargs,
):
    """Load a HuggingFace ``AutoProcessor`` for ``processor_name``.

    Adapted from ``vllm.transformers_utils.image_processor.get_image_processor``.
    Extra positional and keyword arguments are forwarded unchanged to
    ``AutoProcessor.from_pretrained``.

    Raises:
        RuntimeError: when loading fails with ``ValueError`` while
            ``trust_remote_code`` is False (the message suggests enabling it).
        ValueError: re-raised unchanged when ``trust_remote_code`` is True.
    """
    # Imported lazily on purpose: importing it at module top level would
    # call torch.cuda.device_count().
    from transformers import AutoProcessor

    try:
        return AutoProcessor.from_pretrained(
            processor_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
    except ValueError as load_error:
        # Unlike AutoTokenizer, AutoProcessor reports a processor class that
        # does not exist / is not imported as a plain ValueError, so hint at
        # --trust-remote-code when it is disabled.
        if trust_remote_code:
            raise
        hint = ("Failed to load the processor. If the processor is "
                "a custom processor not yet available in the HuggingFace "
                "transformers library, consider setting "
                "`trust_remote_code=True` in LLM or using the "
                "`--trust-remote-code` flag in the CLI.")
        raise RuntimeError(hint) from load_error


# Memoised loader (128 entries, functools.lru_cache's default bound).
cached_get_processor = lru_cache(maxsize=128)(get_processor)
def get_max_meralion_audio_tokens(ctx: InputContext) -> int:
    """Return how many LLM positions a single audio clip occupies.

    This is the number of tokens emitted by the speech audio adapter per
    clip. It is a fixed value; ``ctx`` is not consulted. Presumably it is
    3000 feature frames, halved by the speech encoder (assumption — confirm
    against MERaLiONSpeechEncoder), then merged 15:1 by the adapter:
    3000 / 2 / 15 = 100.
    """
    tokens_per_clip = 100
    return tokens_per_clip
def input_processor_for_meralion(
        ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
    """Expand every speech placeholder token to the per-clip token count.

    Each occurrence of ``hf_config.speech_token_index`` in the prompt is
    replaced by ``get_max_meralion_audio_tokens(ctx)`` copies of itself, so
    the prompt reserves one position per audio-adapter output token.
    Inputs without audio data are returned unchanged.

    Args:
        ctx: vLLM input context (provides the HF config).
        inputs: Decoder-only inputs with ``prompt_token_ids`` and optional
            ``multi_modal_data["audio"]`` (one item or a list of items).

    Returns:
        New token inputs with the expanded ``prompt_token_ids``, the original
        prompt text (if any) and the untouched multimodal data.

    Raises:
        ValueError: if the number of speech placeholder tokens in the prompt
            differs from the number of audio items supplied.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "audio" not in multi_modal_data:
        return inputs

    audios = multi_modal_data["audio"]
    if not isinstance(audios, list):
        audios = [audios]
    if len(audios) == 0:
        return inputs

    # Only the number of clips matters here: every clip expands to the same
    # fixed token count, so no resampling / feature extraction is needed
    # (that work happens once, in the input mapper).
    audio_output_length = get_max_meralion_audio_tokens(ctx)
    speech_token_index = ctx.model_config.hf_config.speech_token_index

    input_ids = inputs["prompt_token_ids"]
    audio_num = input_ids.count(speech_token_index)
    # An explicit raise (not assert) so the check survives `python -O`.
    if audio_num != len(audios):
        raise ValueError(
            f"The text input contains {audio_num} audio tokens, "
            f"but {len(audios)} audios provided")

    new_input_ids: List[int] = []
    start = 0
    for _ in range(audio_num):
        end = input_ids.index(speech_token_index, start)
        new_input_ids.extend(input_ids[start:end])  # text part
        new_input_ids.extend([speech_token_index] * audio_output_length)
        start = end + 1
    new_input_ids.extend(input_ids[start:])

    return token_inputs(
        prompt_token_ids=new_input_ids,
        # Token-only prompts carry no "prompt" key; don't KeyError on them.
        prompt=inputs.get("prompt"),
        multi_modal_data=multi_modal_data,
    )
def input_mapper_for_meralion(
    ctx: InputContext,
    multi_modal_data: Union[Tuple[np.ndarray, Union[int, float]],
                            List[Tuple[np.ndarray, Union[int, float]]]],
) -> MultiModalKwargs:
    """Input mapper for MERaLiON: turn raw audio into model keyword tensors.

    Each item is a ``(waveform, sampling_rate)`` pair. Waveforms are
    resampled to the HF feature extractor's sampling rate, then converted to
    padded (``padding="max_length"``) features plus an attention mask, which
    is renamed to ``feature_attention_mask``.

    Args:
        ctx: vLLM input context (used to locate the HF processor).
        multi_modal_data: One ``(waveform, sampling_rate)`` pair or a list
            of them.

    Returns:
        ``MultiModalKwargs`` with ``input_features`` and
        ``feature_attention_mask``; empty when no audio is given.

    Raises:
        RuntimeError: if the HF processor has no feature extractor.
        Exception: any resampling / extraction error is logged and re-raised.
    """
    if not isinstance(multi_modal_data, list):
        multi_modal_data = [multi_modal_data]
    if len(multi_modal_data) == 0:
        return MultiModalKwargs()

    processor = cached_get_processor(ctx.model_config.model)
    audio_feature_extractor = getattr(processor, "feature_extractor", None)
    if audio_feature_extractor is None:
        raise RuntimeError(
            "No HuggingFace audio_feature_extractor is available "
            "to process the audio object")
    # Resample to, and declare, the same rate so the two cannot disagree.
    target_sr = audio_feature_extractor.sampling_rate

    try:
        resampled_audios = [
            librosa.resample(audio, orig_sr=sampling_rate, target_sr=target_sr)
            for audio, sampling_rate in multi_modal_data
        ]
        batch_data = audio_feature_extractor(resampled_audios,
                                             sampling_rate=target_sr,
                                             return_attention_mask=True,
                                             padding="max_length",
                                             return_tensors="pt").data
        batch_data["feature_attention_mask"] = batch_data.pop("attention_mask")
    except Exception:
        logger.error("Failed to process audio (%s)", multi_modal_data)
        raise

    return MultiModalKwargs(batch_data)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_meralion)
@INPUT_REGISTRY.register_input_processor(input_processor_for_meralion)
@MULTIMODAL_REGISTRY.register_input_mapper("audio",
                                           input_mapper_for_meralion)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "audio", get_max_meralion_audio_tokens)
class MERaLiONForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsPP):
    """MERaLiON AudioLLM: speech encoder + adapter feeding a Gemma2 decoder.

    Audio features are encoded by ``speech_encoder``, layer-normalised by
    ``ln_speech``, projected to the text hidden size by
    ``speech_audio_adapter`` and written into the decoder's input embeddings
    at every ``config.speech_token_index`` position.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.speech_encoder = MERaLiONSpeechEncoder(config.speech_config)
        self.ln_speech = nn.LayerNorm(config.speech_config.d_model)
        self.speech_audio_adapter = MERaLiONSpeechAudioAdaper(
            config.speech_config.d_model, config.text_config.hidden_size)

        self.quant_config = quant_config
        self.text_decoder = Gemma2Model(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "model"))

        self.unpadded_vocab_size = config.text_config.vocab_size
        if config.text_config.tie_word_embeddings:
            # Output projection shares the decoder's input embedding table.
            self.lm_head = self.text_decoder.embed_tokens
        else:
            self.lm_head = ParallelLMHead(config.text_config.vocab_size,
                                          config.text_config.hidden_size,
                                          quant_config=quant_config)
        logit_scale = getattr(config, "logit_scale", 1.0)
        # Gemma2 soft-caps its final logits (`final_logit_softcapping` in the
        # text config); omitting it changes the sampling distribution
        # relative to the reference model. getattr keeps configs without the
        # field working (no capping).
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            config.text_config.vocab_size,
            logit_scale,
            soft_cap=getattr(config.text_config, "final_logit_softcapping",
                             None),
        )
        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.text_decoder.make_empty_intermediate_tensors)

    def _validate_and_reshape_mm_tensor(self,
                                        mm_input: Union[torch.Tensor,
                                                        List[torch.Tensor]],
                                        name: str) -> torch.Tensor:
        """Flatten a batch of per-request tensors along dim 0.

        A tensor is split along its first dimension and the pieces are
        concatenated (merging its first two dimensions); a list of tensors
        is concatenated along dim 0.

        Raises:
            ValueError: if ``mm_input`` is neither a tensor nor a list
                (including ``None``).
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            return torch.concat(list(mm_input))
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[MERaLiONInputs]:
        """Extract audio tensors from the forward kwargs.

        Returns ``None`` when no ``input_features`` were passed. When features
        are present, ``feature_attention_mask`` must be present too, otherwise
        ``_validate_and_reshape_mm_tensor`` raises ``ValueError``.
        """
        input_features = kwargs.pop('input_features', None)
        feature_attention_mask = kwargs.pop('feature_attention_mask', None)
        if input_features is None:
            return None
        # Both results are always torch.Tensor (concatenated), so no further
        # type check is needed afterwards.
        input_features = self._validate_and_reshape_mm_tensor(
            input_features, 'input_features')
        feature_attention_mask = self._validate_and_reshape_mm_tensor(
            feature_attention_mask, 'feature_attention_mask')
        return MERaLiONInputs(input_features=input_features,
                              feature_attention_mask=feature_attention_mask)

    def _process_audio_input(self,
                             audio_input: MERaLiONInputs) -> torch.Tensor:
        """Encode audio and return one row per audio token.

        Returns:
            Tensor of shape ``(total_audio_tokens, text_hidden_size)``, with
            the per-clip outputs flattened in batch order.
        """
        input_features = audio_input["input_features"].to(
            self.speech_encoder.dtype)
        feature_attention_mask = audio_input["feature_attention_mask"]
        audio_outputs = self.speech_encoder(input_features,
                                            attention_mask=feature_attention_mask)
        audio_features = audio_outputs.last_hidden_state
        audio_features = self.ln_speech(audio_features)
        audio_features = self.speech_audio_adapter(audio_features)
        audio_features = audio_features.view(-1, audio_features.size(-1))
        return audio_features

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder, splicing audio embeddings in when present.

        On non-first pipeline stages (``intermediate_tensors`` given) the
        token ids are not used. Otherwise, if audio kwargs are present, the
        token embeddings at speech placeholder positions are overwritten with
        the projected audio features.
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            audio_input = self._parse_and_validate_audio_input(**kwargs)
            if audio_input is None:
                inputs_embeds = None
            else:
                inputs_embeds = self.text_decoder.embed_tokens(input_ids)
                processed_audio_features = self._process_audio_input(
                    audio_input)
                # Merge LLM embeddings and audio features. Assumes the number
                # of placeholder positions equals the number of audio rows
                # (the input processor expands each placeholder accordingly).
                mask = (input_ids == self.config.speech_token_index)
                inputs_embeds[mask, :] = processed_audio_features
                input_ids = None

        hidden_states = self.text_decoder(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, fusing q/k/v and gate/up projections.

        Decoder names are remapped through ``_KEYS_TO_MODIFY_MAPPING``;
        separate ``q/k/v_proj`` and ``gate/up_proj`` tensors are routed into
        the fused ``qkv_proj`` / ``gate_up_proj`` parameters (speech-encoder
        weights are excluded from that fusion and loaded as-is).
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # With tied embeddings lm_head *is* embed_tokens; skip its copy.
            if (self.config.text_config.tie_word_embeddings
                    and "lm_head.weight" in name):
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name or 'speech_encoder' in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
|