from typing import Dict

import numpy as np
import torch


def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a tensor to a NumPy array, even if it requires grad or is on an accelerator."""
    # `.numpy()` raises for tensors that require grad or are not on the CPU,
    # so detach and move to host memory first.
    return tensor.detach().cpu().numpy()


def run_assertion(
    orig_pt_state_dict: Dict[str, torch.Tensor],
    pt_state_dict_from_tf: Dict[str, torch.Tensor],
) -> None:
    """Check that every parameter of a reference state dict was populated correctly.

    For each key in ``orig_pt_state_dict``, the tensor under the same key in
    ``pt_state_dict_from_tf`` is compared with ``np.testing.assert_allclose``
    (default tolerances). Keys that exist only in ``pt_state_dict_from_tf``
    are not checked.

    Args:
        orig_pt_state_dict: Reference parameters, keyed by parameter name.
        pt_state_dict_from_tf: Parameters to verify against the reference.

    Raises:
        ValueError: If a key is missing from ``pt_state_dict_from_tf``, a
            tensor cannot be converted, or the values are not close. The
            message names the offending key and the underlying error is
            chained as ``__cause__``.
    """
    for k, expected in orig_pt_state_dict.items():
        try:
            np.testing.assert_allclose(
                _to_numpy(expected), _to_numpy(pt_state_dict_from_tf[k])
            )
        # Deliberately broad: any failure here (KeyError for a missing key,
        # AssertionError for a value/shape mismatch, TypeError/RuntimeError for
        # unconvertible tensors) is reported uniformly as ValueError, as
        # before, but now with the cause chained instead of swallowed.
        except Exception as err:
            raise ValueError(
                "There are problems in the parameter population process. Cannot proceed :("
                f" Offending parameter: {k!r}."
            ) from err