import torch
import fairseq
from packaging import version
import torch.nn.functional as F
from fairseq import tasks
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from omegaconf import OmegaConf
from s3prl.upstream.interfaces import UpstreamBase
from torch.nn.utils.rnn import pad_sequence


def load_model(filepath):
    """Rebuild a fairseq model and its task from a checkpoint file.

    The checkpoint is loaded onto CPU. The configuration is taken from the
    legacy ``args`` namespace when present, otherwise from ``cfg``. The task
    is set up from that configuration (restoring ``task_state`` if saved),
    the model is built with ``task.build_model`` and, when the checkpoint
    carries a ``model`` state dict, those weights are loaded into it.

    Args:
        filepath: Path to a fairseq checkpoint file.

    Returns:
        A ``(model, cfg, task)`` tuple.

    Raises:
        RuntimeError: If the checkpoint contains neither ``args`` nor ``cfg``.
    """
    # NOTE(review): torch.load unpickles arbitrary objects -- only pass
    # checkpoints that come from a trusted source.
    state = torch.load(filepath, map_location=lambda storage, loc: storage)

    if state.get("args") is not None:
        cfg = convert_namespace_to_omegaconf(state["args"])
    elif state.get("cfg") is not None:
        # OmegaConf.create accepts a plain dict or an existing config, so this
        # normalises cfg to an OmegaConf object. Converting only here avoids a
        # KeyError for args-only checkpoints.
        cfg = OmegaConf.create(state["cfg"])
    else:
        raise RuntimeError(
            f"Neither args nor cfg exist in state keys = {state.keys()}"
        )

    task = tasks.setup_task(cfg.task)
    if "task_state" in state:
        task.load_state_dict(state["task_state"])

    model = task.build_model(cfg.model)
    # build_model only constructs the network from cfg.model; the trained
    # parameters live in state["model"] and must be loaded explicitly (as
    # fairseq's own checkpoint loader does), otherwise the model keeps its
    # random initialisation.
    if "model" in state:
        model.load_state_dict(state["model"], strict=True, model_cfg=cfg.model)

    return model, cfg, task


###################
# UPSTREAM EXPERT #
###################
class UpstreamExpert(UpstreamBase):
    """s3prl upstream exposing a fairseq encoder restored by ``load_model``.

    Per-layer representations are gathered through hooks on the encoder
    layers and the encoder itself; ``forward`` returns the final features
    under the ``"default"`` key.
    """

    def __init__(self, ckpt, **kwargs):
        """Load the checkpoint at ``ckpt`` and register default hooks.

        Args:
            ckpt: Path to a fairseq checkpoint file.
            **kwargs: Forwarded unchanged to ``UpstreamBase``.
        """
        super().__init__(**kwargs)
        assert version.parse(fairseq.__version__) > version.parse(
            "0.10.2"
        ), "Please install the fairseq master branch."

        self.model, _, self.task = load_model(ckpt)

        # Register default hooks only when none were supplied (self.hooks is
        # provided by UpstreamBase).
        if not self.hooks:
            module_name = "self.model.encoder.layers"
            for module_id in range(len(self.model.encoder.layers)):
                self.add_hook(
                    f"{module_name}[{module_id}]",
                    # assumes each layer's first input is (time, batch, dim);
                    # the transpose makes it batch-first -- TODO confirm
                    lambda input, output: input[0].transpose(0, 1),
                )
            self.add_hook("self.model.encoder", lambda input, output: output[0])

    def forward(self, wavs):
        """Extract features from a batch of variable-length waveforms.

        Args:
            wavs: Non-empty list of waveform tensors (presumably 1-D), all on
                the same device as ``wavs[0]``.

        Returns:
            ``{"default": features}``, where ``features`` is the first value
            returned by ``self.model.extract_features``.

        Raises:
            ValueError: If ``wavs`` is empty.
        """
        if not wavs:
            raise ValueError("wavs must contain at least one waveform")

        if self.task.cfg.normalize:
            # Normalise each waveform over all of its elements (no affine).
            wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]

        device = wavs[0].device
        wav_lengths = torch.tensor(
            [len(wav) for wav in wavs], dtype=torch.long, device=device
        )
        # True marks padded positions: sample index >= that utterance's length.
        positions = torch.arange(int(wav_lengths.max()), device=device)
        wav_padding_mask = positions.unsqueeze(0) >= wav_lengths.unsqueeze(1)

        # pad_sequence pads to the longest waveform, matching the mask width.
        padded_wav = pad_sequence(wavs, batch_first=True)
        # NOTE(review): assumes extract_features returns a (features,
        # padding_mask) pair; some fairseq models return a dict instead.
        features, _feat_padding_mask = self.model.extract_features(
            padded_wav,
            padding_mask=wav_padding_mask,
            mask=None,
        )
        return {
            "default": features,
        }