from collections import OrderedDict
from typing import List, Union, Dict

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence

HIDDEN_DIM = 8


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # The model needs to be an nn.Module for finetuning; this is not required for representation extraction
        self.model1 = nn.Linear(1, HIDDEN_DIM)
        self.model2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)

    def forward(self, wavs, upstream_feature_selection="hidden_states"):
        # You can do task-specific pre- / post-processing based on upstream_feature_selection
        hidden = self.model1(wavs)
        # hidden: (batch_size, max_len, hidden_dim)

        feature = self.model2(hidden)
        # feature: (batch_size, max_len, hidden_dim)

        return [hidden, feature]


class UpstreamExpert(nn.Module):
    def __init__(
        self,
        ckpt: str = "./model.pt",
        upstream_feature_selection: str = "hidden_states",
        **kwargs,
    ):
        """
        Args:
            ckpt:
                The checkpoint path for loading your pretrained weights.
                Should be fixed as model.pt for the SUPERB Challenge.
            upstream_feature_selection:
                The value could be
                'hidden_states', 'PR', 'SID', 'ER', 'ASR', 'QbE', 'ASV', 'SD', 'ST', 'SE', 'SS', 'secret', or others (new tasks).
                You can use it to control which task-specific pre- / post-processing to do.
        """
        super().__init__()
        self.name = "[Example UpstreamExpert]"
        self.upstream_feature_selection = upstream_feature_selection

        # You can use ckpt to load your pretrained weights
        ckpt = torch.load(ckpt, map_location="cpu")
        self.model = Model()
        self.model.load_state_dict(ckpt)
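        # Note: the checkpoint here is assumed to be a plain state_dict whose keys match
        # Model's parameter names (e.g. model1.weight, model2.bias); adjust the loading
        # logic if your checkpoint wraps the weights differently.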

    def get_downsample_rates(self, key: str) -> int:
        """
        Since we do not do any downsampling in this example upstream,
        all keys' corresponding representations have a downsample rate of 1.
        E.g. a representation with a 10ms stride would have a downsample rate of 160 (input wavs are all in 16kHz).
        """
        return 1

    def forward(self, wavs: List[Tensor]) -> Dict[str, List[Tensor]]:
        """
        When the returned Dict contains a List with more than one Tensor,
        those Tensors should have the same shape so that a weighted sum can be trained on them.
        """
        wavs = pad_sequence(wavs, batch_first=True).unsqueeze(-1)
        # wavs: (batch_size, max_len, 1)

        hidden_states = self.model(wavs, upstream_feature_selection=self.upstream_feature_selection)
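        # hidden_states: a List of two Tensors, each of shape (batch_size, max_len, hidden_dim)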

        # Deprecated: do not do any task-specific post-processing below.
        # Use the init arg "upstream_feature_selection" to control which task-specific pre- / post-processing to do.

        # The "hidden_states" key will be used as the default in many cases.
        # Other keys in this example are presented for the SUPERB Challenge.
        return {
            "hidden_states": hidden_states,
        }
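

if __name__ == "__main__":
    # Minimal usage sketch: save a dummy "model.pt" from the toy Model above, then run
    # UpstreamExpert on two fake 16kHz waveforms of different lengths. The checkpoint
    # path and waveform durations are illustrative assumptions only.
    torch.save(Model().state_dict(), "./model.pt")

    upstream = UpstreamExpert(ckpt="./model.pt")
    wavs = [torch.randn(16000 * 1), torch.randn(16000 * 2)]  # 1s and 2s of fake 16kHz audio

    with torch.no_grad():
        outputs = upstream(wavs)

    # Each hidden state: (batch_size=2, max_len=32000, hidden_dim=8)
    print([h.shape for h in outputs["hidden_states"]])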