| from transformers import HubertModel |
| from transformers.models.hubert.modeling_hubert import HubertFeatureEncoder |
|
|
| from .configuration import SfiHuBERTConfig |
| from .continuous_filters import FrequencyDomainRFFImplicitFilter |
| from .conv_any_stride import FreqRespSampConv1d |
|
|
|
|
class SfiHuBERTFeatureEncoder(HubertFeatureEncoder):
    """HuBERT convolutional feature encoder with a swapped-in first convolution.

    Everything is built by ``HubertFeatureEncoder`` first. Then the ``conv``
    module of the first conv layer is replaced by a ``FreqRespSampConv1d``
    driven by a ``FrequencyDomainRFFImplicitFilter``. All other layers are
    inherited untouched.
    """

    def __init__(self, config: SfiHuBERTConfig):
        super().__init__(config)
        first_layer = self.conv_layers[0]
        # Reuse the output width of the convolution being replaced, so the
        # layers the parent class built around it keep matching shapes
        # (assumption: they depend only on this width — confirm).
        n_filters = first_layer.conv.out_channels
        # NOTE(review): 640 is a fixed ``n_samples`` value. Its meaning is
        # defined by FreqRespSampConv1d, which is not visible here — confirm.
        first_layer.conv = FreqRespSampConv1d(
            in_channels=1,
            out_channels=n_filters,
            ContFilterType=FrequencyDomainRFFImplicitFilter,
            filter_params=config.latent_filter_params,
            n_samples=640,
        )

    def forward(self, *args, **kwargs):
        """Run the inherited ``HubertFeatureEncoder.forward`` with the same arguments."""
        return super().forward(*args, **kwargs)
|
|
|
|
class SfiHuBERTModel(HubertModel):
    """HuBERT model whose feature encoder is ``SfiHuBERTFeatureEncoder``.

    Use :meth:`set_sample_rate` to configure the encoder's first convolution
    for one of the rates listed in ``config.sfi_conv_parameters``.
    """

    config_class = SfiHuBERTConfig

    def __init__(self, config: SfiHuBERTConfig):
        super().__init__(config)
        self.config = config
        # Override the feature encoder built by the parent with the SFI variant.
        # NOTE(review): this assignment happens after HubertModel.__init__ has
        # finished, so any weight initialisation the parent ran there does not
        # see the new encoder — confirm this is intended.
        self.feature_extractor = SfiHuBERTFeatureEncoder(config)

    def forward(self, *args, **kwargs):
        """Run the inherited ``HubertModel.forward`` with the same arguments."""
        return super().forward(*args, **kwargs)

    def set_sample_rate(self, sample_rate):
        """Configure the first SFI convolution for ``sample_rate``.

        Args:
            sample_rate: Sample rate of the input audio. Any value accepted by
                ``int()`` works (e.g. ``16000``, ``16000.0``, ``"16000"``).

        Raises:
            ValueError: If no entry for the rate exists in
                ``config.sfi_conv_parameters``.
        """
        rate = int(sample_rate)
        # Treat a missing/None parameter table as "no allowed rates" so the
        # caller gets the documented ValueError instead of a TypeError.
        all_params = self.config.sfi_conv_parameters or {}
        # Keys are strings when the config was loaded from JSON (JSON object
        # keys are always strings). They may be ints when the dict was built
        # in Python, so accept either form.
        if str(rate) in all_params:
            conv_params = all_params[str(rate)]
        elif rate in all_params:
            conv_params = all_params[rate]
        else:
            allowed = ", ".join(str(key) for key in all_params)
            raise ValueError(
                f"Sample rate {rate} not in the list of allowed sample rates"
                f" ({allowed})."
            )
        self.feature_extractor.conv_layers[0].conv.prepare(
            sample_rate=rate, **conv_params
        )
|
|