import numpy as np
import torch
from typing import Dict


def run_assertion(
    orig_pt_state_dict: Dict[str, torch.Tensor],
    pt_state_dict_from_tf: Dict[str, torch.Tensor],
) -> None:
    """Check that every parameter in a reference state dict matches a converted one.

    Each key of ``orig_pt_state_dict`` is looked up in ``pt_state_dict_from_tf``
    and the two tensors are compared element-wise with
    ``np.testing.assert_allclose`` (default tolerances). Keys present only in
    ``pt_state_dict_from_tf`` are not checked.

    Args:
        orig_pt_state_dict: Reference mapping of parameter name to tensor.
        pt_state_dict_from_tf: Mapping of parameter name to tensor that was
            populated from another source and should agree with the reference.

    Raises:
        ValueError: If a key is missing from ``pt_state_dict_from_tf``, the
            shapes differ, or the values are not close. The message names the
            offending key, and the underlying exception is chained as
            ``__cause__``.
    """
    for k in orig_pt_state_dict:
        try:
            # detach()/cpu() let tensors that require grad or live on another
            # device be compared, instead of failing inside .numpy() and
            # being misreported as a parameter mismatch.
            np.testing.assert_allclose(
                orig_pt_state_dict[k].detach().cpu().numpy(),
                pt_state_dict_from_tf[k].detach().cpu().numpy(),
            )
        except Exception as err:  # narrower than bare except: keeps Ctrl-C working
            raise ValueError(
                "There are problems in the parameter population process. "
                f"Cannot proceed :( (failed on parameter {k!r})"
            ) from err