lmzjms's picture
Upload 1162 files
0b32ad6 verified
raw
history blame
1.57 kB
import argparse
import logging
from copy import deepcopy
from dataclasses import dataclass, is_dataclass
import torch
from s3prl.util.download import _urls_to_filepaths
from s3prl.util.pseudo_data import get_pseudo_wavs
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
def load_fairseq_ckpt(source: str, **override):
    """Load a fairseq checkpoint onto the CPU and return it with its config.

    Args:
        source: Local checkpoint path or an ``http://`` / ``https://`` URL.
            The value is coerced with ``str()`` first. URLs are passed
            through ``_urls_to_filepaths`` and the result is used as the
            checkpoint path.
        **override: Forwarded to fairseq's ``load_checkpoint_to_cpu`` as
            ``arg_overrides``.

    Returns:
        A ``(state, cfg)`` tuple: the loaded checkpoint state and
        ``state["cfg"]`` converted to a container by OmegaConf.

    Raises:
        AssertionError: If the converted config is not a ``dict``.
    """
    # Function-scope imports: importing this module does not require
    # fairseq or omegaconf to be installed.
    from fairseq.checkpoint_utils import load_checkpoint_to_cpu
    from omegaconf import OmegaConf

    source = str(source)
    # Match the full scheme prefix; a bare "http" test would also treat a
    # local file such as "http_model.pt" as a URL.
    if source.startswith(("http://", "https://")):
        fairseq_path = _urls_to_filepaths(source)
    else:
        fairseq_path = source

    state = load_checkpoint_to_cpu(fairseq_path, arg_overrides=override)
    cfg = OmegaConf.to_container(state["cfg"])
    assert isinstance(cfg, dict)
    return state, cfg
def merge_with_parent(dc: dataclass, cfg: dict):
    """Instantiate the dataclass ``dc`` from the entries of ``cfg`` it declares.

    Keys of ``cfg`` that are not fields of ``dc`` are dropped. The remaining
    values are deep-copied before being passed to ``dc``, and ``cfg`` itself
    is left unmodified.

    Args:
        dc: A dataclass; it is called as ``dc(**kwargs)``.
        cfg: Candidate field values. Any ``dict`` subclass is accepted.

    Returns:
        The object produced by calling ``dc`` with the matching entries.

    Raises:
        AssertionError: If ``dc`` is not a dataclass, ``cfg`` is not a dict,
            or no key of ``cfg`` matches a field of ``dc``.
    """
    assert is_dataclass(dc), f"expected a dataclass, got {dc!r}"
    assert isinstance(cfg, dict), f"expected a dict, got {type(cfg).__name__}"

    field_names = dc.__dataclass_fields__.keys()
    # Filter first, then copy: entries that are about to be discarded are
    # never deep-copied, so an un-copyable stray value cannot break the call.
    matched = deepcopy({k: v for k, v in cfg.items() if k in field_names})
    assert len(matched) > 0, "cfg shares no keys with the dataclass fields"
    return dc(**matched)
def extract_hidden_states(model):
    """Run ``model`` on the pseudo waveforms and return its hidden states.

    The model is put into eval mode (and left there), then called under
    ``torch.no_grad()`` on the output of ``get_pseudo_wavs()``. The value
    stored under the ``"hidden_states"`` key of the result is returned.
    """
    model.eval()
    with torch.no_grad():
        pseudo_wavs = get_pseudo_wavs()
        outputs = model(pseudo_wavs)
        return outputs["hidden_states"]
def are_same_models(model1, model2):
    """Assert that two models produce numerically close hidden states.

    Both models are run through ``extract_hidden_states`` and the results
    are compared pairwise with ``torch.allclose`` using its default
    tolerances.

    Raises:
        AssertionError: If the models return a different number of hidden
            states, or if any pair of hidden states is not close.
    """
    hs1 = list(extract_hidden_states(model1))
    hs2 = list(extract_hidden_states(model2))
    # zip() stops at the shorter input, so without this check a model that
    # returns extra (or fewer) hidden states would still pass.
    assert len(hs1) == len(hs2), (
        f"hidden state count mismatch: {len(hs1)} vs {len(hs2)}"
    )
    for h1, h2 in zip(hs1, hs2):
        assert torch.allclose(h1, h2)
def models_all_close(*models):
    """Assert that every given model matches the first one.

    At least two models are required. Each model after the first is compared
    against the first with ``are_same_models``.
    """
    assert len(models) > 1
    reference, *others = models
    for candidate in others:
        are_same_models(reference, candidate)