|
import torch |
|
import fairseq |
|
from packaging import version |
|
import torch.nn.functional as F |
|
from fairseq import tasks |
|
from fairseq.checkpoint_utils import load_checkpoint_to_cpu |
|
from fairseq.dataclass.utils import convert_namespace_to_omegaconf |
|
from omegaconf import OmegaConf |
|
from s3prl.upstream.interfaces import UpstreamBase |
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
def load_model(filepath):
    """Load a fairseq checkpoint from disk and build its model.

    Args:
        filepath: Path to a fairseq checkpoint file (``.pt``).

    Returns:
        Tuple ``(model, cfg, task)`` — the constructed (untrained-weights)
        model, the resolved OmegaConf config, and the fairseq task.

    Raises:
        RuntimeError: If the checkpoint contains neither ``"args"`` nor
            ``"cfg"``.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    state = torch.load(filepath, map_location=lambda storage, loc: storage)

    # Convert the stored cfg into an OmegaConf container only when present.
    # The original code indexed state["cfg"] unconditionally, which raised
    # KeyError for older checkpoints that carry "args" only — instead of
    # falling through to the args branch / RuntimeError below.
    if state.get("cfg") is not None:
        state["cfg"] = OmegaConf.create(state["cfg"])

    if "args" in state and state["args"] is not None:
        # Legacy checkpoint: config stored as an argparse Namespace.
        cfg = convert_namespace_to_omegaconf(state["args"])
    elif "cfg" in state and state["cfg"] is not None:
        cfg = state["cfg"]
    else:
        raise RuntimeError(
            f"Neither args nor cfg exist in state keys = {state.keys()}"
        )

    task = tasks.setup_task(cfg.task)
    if "task_state" in state:
        # Restore task-specific state (e.g. dictionaries) saved alongside
        # the model weights.
        task.load_state_dict(state["task_state"])

    model = task.build_model(cfg.model)

    return model, cfg, task
|
|
|
|
|
|
|
|
|
|
|
class UpstreamExpert(UpstreamBase):
    """s3prl upstream expert wrapping a fairseq speech encoder.

    Loads a fairseq checkpoint and exposes per-layer hidden states through
    s3prl's hook mechanism, plus the final encoder output under the
    ``"default"`` key of :meth:`forward`'s return dict.
    """

    def __init__(self, ckpt, **kwargs):
        """Build the expert from a fairseq checkpoint path.

        Args:
            ckpt: Path to the fairseq checkpoint file.
            **kwargs: Forwarded to :class:`UpstreamBase`.
        """
        super().__init__(**kwargs)
        # The convert_namespace_to_omegaconf / dataclass config machinery
        # used by load_model only exists past fairseq 0.10.2.
        assert version.parse(fairseq.__version__) > version.parse(
            "0.10.2"
        ), "Please install the fairseq master branch."

        model, cfg, task = load_model(ckpt)
        self.model = model
        self.task = task

        # Register default hooks only if the caller did not supply any.
        if len(self.hooks) == 0:
            module_name = "self.model.encoder.layers"
            # Access the layer list directly rather than eval()-ing the
            # string (the original used eval; the string form is only
            # needed by add_hook, which resolves module paths itself).
            # Each hook returns the layer *input* transposed from
            # (seq, batch, dim) to (batch, seq, dim).
            for module_id in range(len(self.model.encoder.layers)):
                self.add_hook(
                    f"{module_name}[{module_id}]",
                    lambda input, output: input[0].transpose(0, 1),
                )
            self.add_hook("self.model.encoder", lambda input, output: output[0])

    def forward(self, wavs):
        """Extract features from a list of variable-length waveforms.

        Args:
            wavs: List of 1-D float tensors (raw audio), all on the same
                device.

        Returns:
            Dict with key ``"default"`` mapping to the encoder's extracted
            features; per-layer states are exposed via the hooks registered
            in ``__init__``.
        """
        if self.task.cfg.normalize:
            # Match training-time preprocessing: per-utterance layer norm.
            wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]

        device = wavs[0].device
        wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
        # True where a position is padding (index >= utterance length).
        wav_padding_mask = ~torch.lt(
            torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
            wav_lengths.unsqueeze(1),
        )
        padded_wav = pad_sequence(wavs, batch_first=True)

        features, feat_padding_mask = self.model.extract_features(
            padded_wav,
            padding_mask=wav_padding_mask,
            mask=None,  # no masking: inference-time feature extraction
        )
        return {
            "default": features,
        }
|
|
|
|