import argparse
import logging
from copy import deepcopy
from dataclasses import dataclass, is_dataclass

import torch

from s3prl.util.download import _urls_to_filepaths
from s3prl.util.pseudo_data import get_pseudo_wavs

logger = logging.getLogger(__name__)


def load_fairseq_ckpt(source: str, **override):
    """Load a fairseq checkpoint from a URL or a local path.

    Returns the raw checkpoint state dict and its config as a plain dict.
    """
    # Imported lazily so this module can be used without fairseq installed
    from fairseq.checkpoint_utils import load_checkpoint_to_cpu
    from omegaconf import OmegaConf

    source = str(source)
    if source.startswith("http"):
        # Resolve the URL to a locally downloaded filepath
        fairseq_path = _urls_to_filepaths(source)
    else:
        fairseq_path = source

    state = load_checkpoint_to_cpu(fairseq_path, arg_overrides=override)
    cfg = OmegaConf.to_container(state["cfg"])

    assert isinstance(cfg, dict)
    return state, cfg
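

# Usage sketch (not from the original file; the checkpoint path is a
# placeholder, and fairseq must be installed for the call to work):
#
#   state, cfg = load_fairseq_ckpt("/path/to/hubert_base_ls960.pt")
#   model_cfg = cfg.get("model", {})  # fairseq configs keep model args here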


def merge_with_parent(dc: dataclass, cfg: dict):
    """Instantiate dataclass `dc` from `cfg`, dropping keys `dc` does not define."""
    assert is_dataclass(dc)
    assert isinstance(cfg, dict)

    cfg = deepcopy(cfg)

    def fix_cfg(cfg):
        # Remove any key that is not a field of the target dataclass
        target_keys = set(dc.__dataclass_fields__.keys())
        for k in list(cfg.keys()):
            if k not in target_keys:
                del cfg[k]

    fix_cfg(cfg)
    assert len(cfg) > 0
    return dc(**cfg)
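

# Illustrative example (not part of the original file; DummyCfg is a made-up
# dataclass used only to show the key-filtering behavior):
#
#   @dataclass
#   class DummyCfg:
#       dim: int = 768
#       dropout: float = 0.1
#
#   merge_with_parent(DummyCfg, {"dim": 512, "legacy_key": True})
#   # -> DummyCfg(dim=512, dropout=0.1); "legacy_key" is silently dropped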


def extract_hidden_states(model):
    """Run `model` on pseudo waveforms and return its hidden states."""
    model.eval()
    with torch.no_grad():
        return model(get_pseudo_wavs())["hidden_states"]
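

# Sketch (assuming an s3prl-style upstream whose forward returns a dict with
# a "hidden_states" list, one tensor per layer):
#
#   hs = extract_hidden_states(model)
#   for i, h in enumerate(hs):
#       print(i, h.shape)  # typically (batch, seq_len, hidden_dim)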


def are_same_models(model1, model2):
    """Assert that two models produce identical hidden states on the same input."""
    hs1 = extract_hidden_states(model1)
    hs2 = extract_hidden_states(model2)
    # zip() would silently ignore extra layers, so check the counts first
    assert len(hs1) == len(hs2)
    for h1, h2 in zip(hs1, hs2):
        assert torch.allclose(h1, h2)


def models_all_close(*models):
    """Assert that every model matches the first one, layer by layer."""
    assert len(models) > 1
    for model in models[1:]:
        are_same_models(models[0], model)
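

# Usage sketch (the checkpoint paths and the `build_model` loader below are
# placeholders, not real s3prl APIs):
#
#   original = build_model("hubert_base.pt")
#   converted = build_model("hubert_base_converted.pt")
#   models_all_close(original, converted)  # raises AssertionError on mismatch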