from collections import OrderedDict
from typing import List, Union, Dict

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence

# Width of both linear layers in the example model.
HIDDEN_DIM = 8


class Model(nn.Module):
    """Toy two-layer network standing in for a real pretrained upstream.

    The model needs to be an ``nn.Module`` for finetuning; that is not
    required for pure representation extraction.
    """

    def __init__(self):
        super().__init__()
        self.model1 = nn.Linear(1, HIDDEN_DIM)
        self.model2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)

    def forward(self, wavs, upstream_feature_selection="hidden_states"):
        """Run both linear layers and return every intermediate output.

        Args:
            wavs: Padded waveforms whose last dimension is 1, e.g.
                (batch_size, max_len, 1).
            upstream_feature_selection: Hook for task-specified pre- /
                post-processing. This example accepts it but does not use it.

        Returns:
            ``[hidden, feature]``, each of shape
            (batch_size, max_len, hidden_dim).
        """
        hidden = self.model1(wavs)
        # hidden: (batch_size, max_len, hidden_dim)

        feature = self.model2(hidden)
        # feature: (batch_size, max_len, hidden_dim)

        return [hidden, feature]


class UpstreamExpert(nn.Module):
    """Example upstream wrapper exposing ``Model`` through the expert API."""

    def __init__(
        self,
        ckpt: str = "./",
        upstream_feature_selection: str = "hidden_states",
        **kwargs,
    ):
        """
        Args:
            ckpt:
                The checkpoint path for loading your pretrained weights.
                Should be fixed as for SUPERB Challenge. It is passed to
                ``torch.load`` and must contain a state dict for ``Model``.
            upstream_feature_selection:
                The value could be
                'hidden_states', 'PR', 'SID', 'ER', 'ASR', 'QbE', 'ASV',
                'SD', 'ST', 'SE', 'SS', 'secret', or others (new tasks).
                You can use it to control which task-specified pre- /
                post-processing to do.
            **kwargs:
                Accepted for interface compatibility; ignored here.
        """
        super().__init__()
        # The original line read `super().__init__() = "..."`, which is not
        # valid Python; the string is the human-readable name of the expert.
        self.name = "[Example UpstreamExpert]"
        self.upstream_feature_selection = upstream_feature_selection

        # You can use ckpt to load your pretrained weights.
        # NOTE(security): torch.load may unpickle arbitrary objects; only load
        # checkpoints from a trusted source.
        state_dict = torch.load(ckpt, map_location="cpu")
        self.model = Model()
        self.model.load_state_dict(state_dict)

    def get_downsample_rates(self, key: str) -> int:
        """Return the downsample rate of the representation under ``key``.

        Since we do not do any downsampling in this example upstream, all
        keys' corresponding representations have a downsample rate of 1.
        E.g. a 10ms stride representation has the downsample rate 160
        (input wavs are all in 16kHz).
        """
        return 1

    def forward(self, wavs: List[Tensor]) -> Dict[str, List[Tensor]]:
        """Pad a batch of waveforms and extract the model's representations.

        When the returned Dict contains a List with more than one Tensor,
        those Tensors should be in the same shape to train a weighted-sum
        on them.

        Args:
            wavs: List of 1-D waveform tensors, possibly of different lengths.

        Returns:
            A dict with the single key ``"hidden_states"`` mapping to the
            list of tensors returned by ``Model.forward``.
        """
        # Zero-pad to the longest waveform and add a trailing feature axis.
        wavs = pad_sequence(wavs, batch_first=True).unsqueeze(-1)
        # wavs: (batch_size, max_len, 1)

        hidden_states = self.model(
            wavs, upstream_feature_selection=self.upstream_feature_selection
        )

        # Deprecated! Do not do any task-specified postprocess below.
        # You can use the init arg "upstream_feature_selection" to control
        # which task-specified pre- / post-processing to do.

        # The "hidden_states" key will be used as default in many cases.
        # Other keys in this example are presented for SUPERB Challenge.
        return {
            "hidden_states": hidden_states,
        }