import torch
from torch import Tensor, nn
from transformers import PreTrainedModel

from .config import AdapterConfig


class Model(nn.Module):
    """Strided 1-D convolution encoder paired with a transposed-convolution decoder.

    ``encode`` applies a reflect-padded, bias-free ``Conv1d``
    (``num_channels`` -> ``num_filters``) followed by ``tanh``.
    ``decode`` is a bias-free ``ConvTranspose1d`` (``num_filters`` ->
    ``num_channels``) that uses the same kernel size, stride and padding.

    Args:
        num_channels: Channels of the signal fed to ``encode``.
        num_filters: Channels produced by ``encode``.
        window_length: Kernel size of both layers.
        stride: Stride of both layers; also stored as ``self.stride``.
    """

    def __init__(
        self,
        num_channels: int,
        num_filters: int,
        window_length: int,
        stride: int,
    ):
        super().__init__()
        self.stride = stride

        # Both layers share kernel size, stride and padding.
        # NOTE(review): when window_length == 2 * stride with an even stride
        # (the values Adapter uses), this padding makes decode(encode(x)) return
        # the input length for lengths that are multiples of stride. Confirm
        # before relying on this for other settings.
        pad = window_length // 2 - stride // 2
        shared = {
            "kernel_size": window_length,
            "stride": stride,
            "padding": pad,
            "bias": False,
        }

        # Registering conv before decode keeps parameter / state_dict ordering.
        self.conv = nn.Conv1d(
            num_channels, num_filters, padding_mode="reflect", **shared
        )
        # `decode` is a submodule (not a method); calling it runs the layer.
        self.decode = nn.ConvTranspose1d(num_filters, num_channels, **shared)

    def encode(self, x: Tensor) -> Tensor:
        """Return ``tanh`` of the strided convolution of ``x``.

        NOTE(review): presumably ``x`` is laid out as Conv1d expects,
        (batch, num_channels, length); confirm against callers.
        """
        features = self.conv(x)
        return torch.tanh(features)


class Adapter(PreTrainedModel):
    """``PreTrainedModel`` wrapper exposing a fixed-size ``Model``.

    The inner model is always built with 2 channels, 128 filters, window length
    128 and stride 64; none of these values are read from ``config``.
    """

    config_class = AdapterConfig

    def __init__(self, config: AdapterConfig):
        super().__init__(config)
        # Hyperparameters are hard-coded here rather than taken from `config`.
        self.model = Model(
            num_channels=2,
            num_filters=128,
            window_length=128,
            stride=64,
        )

    def encode(self, x: Tensor) -> Tensor:
        """Encode ``x`` with the wrapped model's ``encode``."""
        inner = self.model
        return inner.encode(x)

    def decode(self, x: Tensor) -> Tensor:
        """Run ``x`` through the wrapped model's ``decode`` transposed convolution."""
        inner = self.model
        return inner.decode(x)