| |
| |
| |
| |
|
|
| import copy |
| import os |
| from typing import Union |
|
|
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.utils import logging |
|
|
# Module-level logger from transformers' logging utilities, named after this module;
# AudioConfig uses it to report its settings on construction.
logger = logging.get_logger(__name__)
|
|
|
|
class AudioConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an audio encoder model.
    It is used to instantiate an audio encoder according to the specified arguments,
    defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
    Read the documentation from [`PretrainedConfig`] for more information.

    Args:
        speech_encoder (`str`, *optional*, defaults to `"whisper-base"`):
            Path or name of the speech encoder model.
        speech_encoder_type (`str`, *optional*, defaults to `"whisper"`):
            Type of speech encoder to use.
        speech_projector_type (`str`, *optional*, defaults to `"linear"`):
            Type of speech projector to use for feature alignment.
        speech_encoder_ds_rate (`int`, *optional*, defaults to 5):
            Downsampling rate for speech features.
        speech_encoder_hidden_size (`int`, *optional*, defaults to 1280):
            Hidden size of the speech encoder.
            NOTE(review): 1280 is the Whisper-large encoder width, while the default
            `"whisper-base"` encoder uses 512 -- set this explicitly to match the
            encoder that is actually loaded.
        mel_bins (`int`, *optional*, defaults to 80):
            Number of mel-frequency bins for spectrogram features.
        sample_rate (`int`, *optional*, defaults to 16000):
            Audio sample rate in Hz.
        frame_length (`float`, *optional*, defaults to 25.0):
            Frame length in milliseconds for audio processing.
        frame_shift (`float`, *optional*, defaults to 10.0):
            Frame shift in milliseconds for audio processing.
        use_beats (`bool`, *optional*, defaults to `False`):
            Whether to use a BEATs model for audio feature extraction.
        beats_model_path (`str`, *optional*, defaults to `None`):
            Path to the BEATs model, used when `use_beats` is `True`.
        whisper_config (`dict`, *optional*, defaults to `None`):
            Configuration dictionary for Whisper model parameters. `None` (or any
            other falsy value) is stored as an empty dict.
        kwargs (*optional*):
            Remaining keyword arguments, forwarded to [`PretrainedConfig`].
    """

    model_type = 'audio_encoder'

    def __init__(
        self,
        speech_encoder="whisper-base",
        speech_encoder_type="whisper",
        speech_projector_type="linear",
        speech_encoder_ds_rate=5,
        speech_encoder_hidden_size=1280,
        mel_bins=80,
        sample_rate=16000,
        frame_length=25.0,
        frame_shift=10.0,
        use_beats=False,
        beats_model_path=None,
        whisper_config=None,
        **kwargs,
    ):
        # Let the base class consume the generic config kwargs first.
        super().__init__(**kwargs)

        self.speech_encoder = speech_encoder
        self.speech_encoder_type = speech_encoder_type
        self.speech_projector_type = speech_projector_type
        self.speech_encoder_ds_rate = speech_encoder_ds_rate
        self.speech_encoder_hidden_size = speech_encoder_hidden_size
        self.mel_bins = mel_bins
        self.sample_rate = sample_rate
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.use_beats = use_beats
        self.beats_model_path = beats_model_path
        self.whisper_config = whisper_config or {}

        # Lazy %-style arguments: the message is only rendered if INFO is enabled.
        logger.info('Audio Config - Speech Encoder: %s', self.speech_encoder)
        logger.info('Audio Config - Encoder Type: %s', self.speech_encoder_type)
        logger.info('Audio Config - Projector Type: %s', self.speech_projector_type)
        logger.info('Audio Config - Downsampling Rate: %s', self.speech_encoder_ds_rate)
        logger.info('Audio Config - Hidden Size: %s', self.speech_encoder_hidden_size)
        logger.info('Audio Config - Mel Bins: %s', self.mel_bins)
        logger.info('Audio Config - Sample Rate: %s', self.sample_rate)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
        """
        Load an `AudioConfig` from a checkpoint name or local path.

        Args:
            pretrained_model_name_or_path: Name or path passed to `get_config_dict`.
            **kwargs: Loading options and config overrides, forwarded to
                `get_config_dict` and then `from_dict`.

        Returns:
            The config instance built by `from_dict`.
        """
        # `_set_token_in_kwargs` is a private PretrainedConfig helper (presumably for the
        # legacy auth-token kwarg); some transformers releases may not provide it, so only
        # call it when it exists instead of raising AttributeError.
        set_token_in_kwargs = getattr(cls, '_set_token_in_kwargs', None)
        if set_token_in_kwargs is not None:
            set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # Mirror the base-class sanity check that this override would otherwise bypass.
        if 'model_type' in config_dict and config_dict['model_type'] != cls.model_type:
            logger.warning(
                'You are using a model of type %s to instantiate a model of type %s. '
                'This is not supported for all configurations of models and can yield errors.',
                config_dict['model_type'],
                cls.model_type,
            )

        return cls.from_dict(config_dict, **kwargs)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary.

        Delegates to `PretrainedConfig.to_dict` so the base-class serialization rules
        (dropping private bookkeeping keys, stringifying `torch_dtype`, recording
        `transformers_version`) are applied, and always records this class's `model_type`.
        """
        output = super().to_dict()
        output['model_type'] = self.__class__.model_type
        return output
|
|