from transformers import PreTrainedModel, AutoConfig, AutoModel from .model import SincNet from .config import SincNetConfig class SincNetModel(PreTrainedModel): config_class = SincNetConfig base_model_prefix = "sincnet" def __init__(self, config: SincNetConfig): super().__init__(config) self.model = SincNet( sinc_filter_stride=config.stride, num_sinc_filters=config.num_sinc_filters, sinc_filter_length=config.sinc_filter_length, num_conv_filters=config.num_conv_filters, conv_filter_length=config.conv_filter_length, pool_kernel_size=config.pool_kernel_size, pool_stride=config.pool_stride, sample_rate=config.sample_rate, ) def forward(self, waveforms): return self.model(waveforms) AutoConfig.register('sincnet', SincNetConfig) AutoModel.register(SincNetConfig, SincNetModel)