Spaces:
Runtime error
Runtime error
import numpy as np | |
from loguru import logger | |
import torch | |
from temps.temps_arch import EncoderPhotometry, MeasureZ | |
from temps.temps import TempsModule | |
def test(): | |
nn_features = EncoderPhotometry() | |
nn_features.load_state_dict( | |
torch.load("data/models/modelF_DA.pt", map_location=torch.device("cpu")) | |
) | |
nn_z = MeasureZ(num_gauss=6) | |
nn_z.load_state_dict( | |
torch.load("data/models/modelZ_DA.pt", map_location=torch.device("cpu")) | |
) | |
temps_module = TempsModule(nn_features, nn_z) | |
col = np.array( | |
[0.54804805, 1.81142339, 0.63354394, 0.7356338, 1.3578122, 0.90108565] | |
) | |
ztrue = 0.4446 | |
z, pz, odds = temps_module.get_pz( | |
input_data=torch.Tensor(col).unsqueeze(0), return_pz=True, return_flag=True | |
) | |
zdiff = np.abs(z - ztrue).mean() | |
logger.info(f"zdiff: {zdiff}") | |
logger.info("test passed") | |
assert zdiff < 0.01 | |
test() | |
# %% | |