sayakpaul's picture
sayakpaul HF staff
apply styling.
3304f7d
from typing import Dict
import numpy as np
import torch
def run_assertion(
    orig_pt_state_dict: Dict[str, torch.Tensor],
    pt_state_dict_from_tf: Dict[str, torch.Tensor],
) -> None:
    """Verify that every parameter of a reference state dict is reproduced.

    For each key in ``orig_pt_state_dict``, the tensor stored under the same
    key in ``pt_state_dict_from_tf`` is compared with
    ``np.testing.assert_allclose`` (NumPy's default tolerances). Keys present
    only in ``pt_state_dict_from_tf`` are not checked.

    Args:
        orig_pt_state_dict: Reference state dict mapping parameter names to
            tensors.
        pt_state_dict_from_tf: State dict to validate against the reference
            (presumably populated from TensorFlow weights, per its name).

    Raises:
        ValueError: If a reference key is missing from
            ``pt_state_dict_from_tf``, or if a pair of tensors cannot be
            compared or is not close. The underlying error is chained as
            ``__cause__``, and the message names the offending key.
    """
    for k, orig_tensor in orig_pt_state_dict.items():
        try:
            # detach().cpu() lets tensors that require grad or live on an
            # accelerator be converted; bare .numpy() raises for those.
            np.testing.assert_allclose(
                orig_tensor.detach().cpu().numpy(),
                pt_state_dict_from_tf[k].detach().cpu().numpy(),
            )
        except Exception as err:
            # Any ordinary failure (KeyError for a missing key, AssertionError
            # for a mismatch, ...) is still reported as ValueError, as before,
            # but KeyboardInterrupt/SystemExit are no longer swallowed.
            raise ValueError(
                "There are problems in the parameter population process "
                f"(parameter {k!r}). Cannot proceed :("
            ) from err