Upload folder using huggingface_hub

- README.md +1 -1
- config.json +1 -6
- configuration_midashenglm.py +5 -11
- model.safetensors.index.json +13 -13
- modeling_midashenglm.py +29 -47
- processing_midashenglm.py +23 -20
    	
README.md CHANGED

@@ -51,7 +51,7 @@ base_model:
 
 >>> import torch
 >>> with torch.no_grad():
-...     model_inputs = processor(text=text, audio=audio)
+...     model_inputs = processor(text=text, audio=audio, sampling_rate=sr)
 ...     generation = model.generate(**model_inputs)
 ...     output = processor.batch_decode(generation, skip_special_tokens=True)
 
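For context, a minimal end-to-end sketch of the updated call, with the sampling rate now passed explicitly. The repository id, audio-loading helper, and prompt below are placeholders for illustration and are not part of this commit:

import torch
import soundfile as sf
from transformers import AutoModelForCausalLM, AutoProcessor

repo = "<model-repo-id>"  # placeholder; use the actual model repository

processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

audio, sr = sf.read("example.wav")  # mono waveform and its sampling rate
text = "Describe the audio."        # prompt construction simplified here

with torch.no_grad():
    model_inputs = processor(text=text, audio=audio, sampling_rate=sr)
    generation = model.generate(**model_inputs)
    output = processor.batch_decode(generation, skip_special_tokens=True)
print(output)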
    	
config.json CHANGED

@@ -37,15 +37,10 @@
     "AutoConfig": "configuration_midashenglm.MiAudioLLMHFConfig",
     "AutoModelForCausalLM": "modeling_midashenglm.DashengQwen25OmniModelInstruct"
   },
-  "freeze": null,
-  "gradient_checkpoint_decoder": false,
-  "lora": null,
-  "model": "DashengQwen25OmniModelInstruct",
   "model_type": "miaudiollm",
   "resize_tokenizer": false,
   "subsample_factor": 5,
-  "text_model_config": {
-    "_attn_implementation_autoset": true,
+  "text_config": {
     "attention_dropout": 0.0,
     "hidden_act": "silu",
     "hidden_size": 2048,
    	
configuration_midashenglm.py CHANGED

@@ -1,5 +1,5 @@
 from ast import Dict
-from typing import Literal, Tuple, Union
+from typing import Tuple, Union
 
 from transformers import PretrainedConfig
 from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
@@ -66,22 +66,16 @@ class MiAudioLLMHFConfig(PretrainedConfig):
 
     def __init__(
         self,
-        model: str = "DashengQwen2ModelInstruct",
         audio_encoder_config: Dict = {},
-        freeze: Literal["audio", "text"] | str | None = None,
-        lora: Literal["encoder", "decoder"] | None = None,
         subsample_factor: int = 5,
-        text_model_config: Dict = None,
+        text_config: Dict = None,
         **kwargs,
     ):
-        self.model = model
         self.audio_encoder_config = DashengConfig(**audio_encoder_config)
-        self.freeze = freeze
-        self.lora = lora
         self.subsample_factor = subsample_factor
-        self.text_model_config = (
-            Qwen2_5OmniTextConfig(**text_model_config)
-            if text_model_config
+        self.text_config = (
+            Qwen2_5OmniTextConfig(**text_config)
+            if text_config
             else Qwen2_5OmniTextConfig()
         )
         super().__init__(**kwargs)
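A minimal sketch of how the renamed `text_config` argument is consumed: it is an ordinary dict forwarded to Qwen2_5OmniTextConfig, falling back to Qwen2_5OmniTextConfig() defaults when omitted. The values below are illustrative and assume the module is importable from a checkout of this repository:

from configuration_midashenglm import MiAudioLLMHFConfig

config = MiAudioLLMHFConfig(
    audio_encoder_config={},  # Dasheng encoder defaults
    subsample_factor=5,
    text_config={"hidden_size": 2048, "hidden_act": "silu"},  # illustrative values
)
print(type(config.text_config).__name__)  # Qwen2_5OmniTextConfig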
    	
model.safetensors.index.json CHANGED

@@ -390,20 +390,20 @@
     "audio_encoder.freq_pos_embed": "model-00001-of-00002.safetensors",
     "audio_encoder.front_end.0.mel_scale.fb": "model-00001-of-00002.safetensors",
     "audio_encoder.front_end.0.spectrogram.window": "model-00001-of-00002.safetensors",
-    "audio_encoder.init_bn.1.bias": "model-00001-of-00002.safetensors",
-    "audio_encoder.init_bn.1.num_batches_tracked": "model-00001-of-00002.safetensors",
-    "audio_encoder.init_bn.1.running_mean": "model-00001-of-00002.safetensors",
-    "audio_encoder.init_bn.1.running_var": "model-00001-of-00002.safetensors",
-    "audio_encoder.init_bn.1.weight": "model-00001-of-00002.safetensors",
+    "audio_encoder.init_bn.bias": "model-00001-of-00002.safetensors",
+    "audio_encoder.init_bn.num_batches_tracked": "model-00001-of-00002.safetensors",
+    "audio_encoder.init_bn.running_mean": "model-00001-of-00002.safetensors",
+    "audio_encoder.init_bn.running_var": "model-00001-of-00002.safetensors",
+    "audio_encoder.init_bn.weight": "model-00001-of-00002.safetensors",
     "audio_encoder.norm.bias": "model-00001-of-00002.safetensors",
     "audio_encoder.norm.weight": "model-00001-of-00002.safetensors",
     "audio_encoder.patch_embed.proj.bias": "model-00001-of-00002.safetensors",
     "audio_encoder.patch_embed.proj.weight": "model-00001-of-00002.safetensors",
     "audio_encoder.time_pos_embed": "model-00001-of-00002.safetensors",
-    "audio_projector.net.0.bias": "model-00002-of-00002.safetensors",
-    "audio_projector.net.0.weight": "model-00002-of-00002.safetensors",
-    "audio_projector.net.2.bias": "model-00002-of-00002.safetensors",
-    "audio_projector.net.2.weight": "model-00002-of-00002.safetensors",
+    "audio_projector.net.0.bias": "model-00001-of-00002.safetensors",
+    "audio_projector.net.0.weight": "model-00001-of-00002.safetensors",
+    "audio_projector.net.2.bias": "model-00001-of-00002.safetensors",
+    "audio_projector.net.2.weight": "model-00001-of-00002.safetensors",
     "decoder.lm_head.weight": "model-00002-of-00002.safetensors",
     "decoder.model.embed_tokens.weight": "model-00001-of-00002.safetensors",
     "decoder.model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -442,11 +442,11 @@
     "decoder.model.layers.10.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
     "decoder.model.layers.10.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
     "decoder.model.layers.10.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
-    "decoder.model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
-    "decoder.model.layers.11.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "decoder.model.layers.11.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "decoder.model.layers.11.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
     "decoder.model.layers.11.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
-    "decoder.model.layers.11.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
-    "decoder.model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "decoder.model.layers.11.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+    "decoder.model.layers.11.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
     "decoder.model.layers.11.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
     "decoder.model.layers.11.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
     "decoder.model.layers.11.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
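For reference, a short sketch of inspecting the updated index; it assumes the standard safetensors index layout, where `weight_map` maps parameter names to shard files:

import json
from collections import Counter

with open("model.safetensors.index.json") as f:
    index = json.load(f)

# Count how many tensors live in each shard.
shard_counts = Counter(index["weight_map"].values())
for shard, n in sorted(shard_counts.items()):
    print(f"{shard}: {n} tensors")

# Confirm where the renamed BatchNorm parameters now live.
print(index["weight_map"]["audio_encoder.init_bn.weight"])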
    	
modeling_midashenglm.py CHANGED

@@ -249,21 +249,12 @@ class Block(nn.Module):
         return x
 
 
-class RearranceReplace(nn.Module):
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # rearrange(x, "b c f t -> b f c t")
-        # or
-        # rearrange(x, "b f c t -> b c f t")
-        return torch.permute(x, (0, 2, 1, 3))
-
-
-class AudioTransformer(nn.Module):
-    def __init__(
-        self,
-        config: DashengConfig,
-    ):
-        super().__init__()
+class AudioTransformer(PreTrainedModel):
+    config_class = DashengConfig
+
+    def __init__(self, config: DashengConfig):
+        super().__init__(config)
+
         self.target_length = config.target_length
         self.embed_dim = config.embed_dim
         self.hop_length = config.hop_length
@@ -282,13 +273,7 @@ class AudioTransformer(nn.Module):
             audio_transforms.AmplitudeToDB(top_db=120),
         )
 
-        self.init_bn = nn.Sequential(
-            # Rearrange("b c f t -> b f c t"),
-            RearranceReplace(),
-            nn.BatchNorm2d(config.n_mels, momentum=0.01),
-            # Rearrange("b f c t -> b c f t"),
-            RearranceReplace(),
-        )
+        self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)
 
         self.patch_embed = AudioPatchEmbed(
             input_size=(config.n_mels, config.target_length),
@@ -327,6 +312,8 @@ class AudioTransformer(nn.Module):
         )
         self.norm = norm_layer(config.embed_dim)
 
+        self.post_init()
+
     def forward_features(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
         t = x.shape[-1]
         x = x + self.time_pos_embed[:, :, :, :t]
@@ -357,7 +344,9 @@ class AudioTransformer(nn.Module):
         x = self.front_end(x)
         target_length_in_patches = self.target_length // 4
         x = x.unsqueeze(1)
+        x = torch.permute(x, (0, 2, 1, 3))
         x = self.init_bn(x)
+        x = torch.permute(x, (0, 2, 1, 3))
 
         x = self.patch_embed(x)
         t = x.shape[-1]
@@ -427,23 +416,21 @@ class DashengQwen25OmniModelInstructOutput(ModelOutput):
 
 class Decoder(PreTrainedModel, GenerationMixin):
     config_class = Qwen2_5OmniTextConfig
+    _supports_flash_attn_2 = Qwen2_5OmniThinkerTextModel._supports_flash_attn_2
+    _supports_sdpa = Qwen2_5OmniThinkerTextModel._supports_sdpa
+    _supports_flex_attn = Qwen2_5OmniThinkerTextModel._supports_flex_attn
+    _supports_cache_class = Qwen2_5OmniThinkerTextModel._supports_cache_class
+    _supports_static_cache = Qwen2_5OmniThinkerTextModel._supports_static_cache
+    _supports_quantized_cache = Qwen2_5OmniThinkerTextModel._supports_quantized_cache
 
     def __init__(self, config: Qwen2_5OmniTextConfig):
         super().__init__(config)
-        self.model = Qwen2_5OmniThinkerTextModel._from_config(
-            config,
-            attn_implementation="sdpa",  # TODO
-        )
+        self.model = Qwen2_5OmniThinkerTextModel._from_config(config)
         self.lm_head = nn.Linear(
             config.hidden_size,
             config.vocab_size,
             bias=False,
         )
-        # TODO fix dtype
-        self.lm_head.weight.data = self.lm_head.weight.data.to(
-            self.model.embed_tokens.weight.dtype
-        )
-        # TODO tie weight?
         self.post_init()
 
     def forward(
@@ -481,30 +468,25 @@ class Decoder(PreTrainedModel, GenerationMixin):
 
 class DashengQwen25OmniModelInstruct(PreTrainedModel):
     config_class = MiAudioLLMHFConfig
+    _supports_flash_attn_2 = Qwen2_5OmniThinkerTextModel._supports_flash_attn_2
+    _supports_sdpa = Qwen2_5OmniThinkerTextModel._supports_sdpa
+    _supports_flex_attn = Qwen2_5OmniThinkerTextModel._supports_flex_attn
+    _supports_cache_class = Qwen2_5OmniThinkerTextModel._supports_cache_class
+    _supports_static_cache = Qwen2_5OmniThinkerTextModel._supports_static_cache
+    _supports_quantized_cache = Qwen2_5OmniThinkerTextModel._supports_quantized_cache
 
     def __init__(self, config: MiAudioLLMHFConfig):
         super().__init__(config)
 
-        freeze = config.freeze
-        lora = config.lora
-        subsample_factor = config.subsample_factor
-
-        self.subsample_factor = subsample_factor
-        self.lora = lora
-        # Encoder part
-        self.audio_encoder = AudioTransformer(config.audio_encoder_config)
-        assert lora != "encoder"
-
-        # decoder
-        self.decoder = Decoder(config.text_model_config)
-        assert lora != "decoder"
-        assert freeze is None
-
-        # audio projector
+        self.audio_encoder = AudioTransformer._from_config(config.audio_encoder_config)
         self.audio_projector = AudioProjectorSubsample(
             self.audio_encoder.embed_dim,
-            config.text_model_config.hidden_size,
-            subsample_factor,
+            config.text_config.hidden_size,
+            config.subsample_factor,
+        )
+        self.decoder = Decoder._from_config(
+            config.text_config,
+            attn_implementation=config._attn_implementation,
         )
 
         self.post_init()
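The removed RearranceReplace helper and the new inline permutes are equivalent: BatchNorm2d over n_mels needs the mel axis in the channel position, so the input is permuted "b c f t -> b f c t", normalised, and permuted back. A small self-contained sketch (shapes are illustrative):

import torch

b, c, f, t = 2, 1, 64, 100          # illustrative batch/channel/mel/time sizes
x = torch.randn(b, c, f, t)
init_bn = torch.nn.BatchNorm2d(f, momentum=0.01)

y = torch.permute(x, (0, 2, 1, 3))  # b c f t -> b f c t
y = init_bn(y)                      # normalise over the mel axis
y = torch.permute(y, (0, 2, 1, 3))  # b f c t -> b c f t
assert y.shape == x.shape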
    	
processing_midashenglm.py CHANGED

@@ -55,32 +55,35 @@ class MiAudioLLMProcessor(ProcessorMixin):
         tokenizer: Qwen2Tokenizer | Qwen2TokenizerFast | None = None,
         model_subsampling: int = 5,
         chat_template: str | None = None,
-        audio_token: str = "<|audio|>",
-        audio_bos_token: str = "<|audio_bos|>",
-        audio_eos_token: str = "<|audio_eos|>",
+        audio_token: str | None = None,
+        audio_bos_token: str | None = None,
+        audio_eos_token: str | None = None,
     ):
-        if chat_template is None:
-            chat_template = self.default_chat_template
         assert tokenizer is not None, "Tokenizer Needs to be passed"
-        self.audio_token = (
-            tokenizer.audio_token
-            if hasattr(tokenizer, "audio_token")
-            else audio_token
+        assert audio_token is not None or hasattr(tokenizer, "audio_token"), (
+            "Either `audio_token` must be provided or tokenizer must have `audio_token` attribute."
         )
-        self.audio_bos_token = (
-            tokenizer.audio_bos_token
-            if hasattr(tokenizer, "audio_bos_token")
-            else audio_bos_token
+        assert audio_bos_token is not None or hasattr(tokenizer, "audio_bos_token"), (
+            "Either `audio_bos_token` must be provided or tokenizer must have `audio_bos_token` attribute."
         )
-        self.audio_eos_token = (
-            tokenizer.audio_eos_token
-            if hasattr(tokenizer, "audio_eos_token")
-            else audio_eos_token
+        assert audio_eos_token is not None or hasattr(tokenizer, "audio_eos_token"), (
+            "Either `audio_eos_token` must be provided or tokenizer must have `audio_eos_token` attribute."
         )
+
+        if chat_template is None:
+            chat_template = self.default_chat_template
+
+        self.audio_token: str = audio_token or tokenizer.audio_token
+        self.audio_bos_token = audio_bos_token or tokenizer.audio_bos_token
+        self.audio_eos_token = audio_eos_token or tokenizer.audio_eos_token
+        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
         self.model_subsampling = model_subsampling
-        if feature_extractor is not None:
-            feature_extractor.do_normalize = False
+
+        if feature_extractor is not None:
+            assert not feature_extractor.do_normalize, (
+                "This model does not use normalization. Please set `do_normalize=False` in the feature extractor."
+            )
+
         super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
 
     def __call__(
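A small sketch of the new fallback behaviour: an explicitly passed token wins, the tokenizer attribute is the fallback, and construction fails fast when neither is available. The tokenizer id and token string below are purely illustrative:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")  # illustrative tokenizer
audio_token = "<|audio|>"  # explicitly provided, so no tokenizer attribute is required

# Mirrors the new __init__ logic.
assert audio_token is not None or hasattr(tokenizer, "audio_token"), (
    "Either `audio_token` must be provided or tokenizer must have `audio_token` attribute."
)
resolved = audio_token or tokenizer.audio_token  # falls back only when audio_token is None
audio_token_id = tokenizer.convert_tokens_to_ids(resolved)
print(resolved, audio_token_id)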
