File size: 922 Bytes
21a7d1b
 
 
 
 
 
 
 
 
 
c0ee1e2
 
 
21a7d1b
c0ee1e2
 
 
21a7d1b
 
 
c0ee1e2
 
 
999af9c
21a7d1b
c0ee1e2
668e440
c0ee1e2
21a7d1b
37eb1ba
21a7d1b
c0ee1e2
21a7d1b
 
 
 
 
c0ee1e2
21a7d1b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import numpy as np
from loguru import logger
import torch

from temps.temps_arch import EncoderPhotometry, MeasureZ
from temps.temps import TempsModule


def test():
    """Regression check of the photo-z pipeline on one reference input.

    Loads the pretrained ``EncoderPhotometry`` and ``MeasureZ`` weights from
    ``data/models`` (paths are relative to the current working directory),
    wraps them in ``TempsModule``, predicts the redshift for a single
    hard-coded six-element input vector, and checks that the mean absolute
    difference from the reference value ``ztrue`` is below ``tolerance``.

    Raises:
        AssertionError: if the prediction deviates by ``tolerance`` or more.
    """
    cpu = torch.device("cpu")

    nn_features = EncoderPhotometry()
    nn_features.load_state_dict(
        torch.load("data/models/modelF_DA.pt", map_location=cpu)
    )
    # num_gauss must agree with the saved checkpoint; load_state_dict is
    # strict by default and raises on a parameter mismatch.
    nn_z = MeasureZ(num_gauss=6)
    nn_z.load_state_dict(
        torch.load("data/models/modelZ_DA.pt", map_location=cpu)
    )

    temps_module = TempsModule(nn_features, nn_z)

    # Six input features for one object (presumably photometric colours --
    # TODO confirm) together with its reference redshift.
    col = np.array(
        [0.54804805, 1.81142339, 0.63354394, 0.7356338, 1.3578122, 0.90108565]
    )
    ztrue = 0.4446
    tolerance = 0.01

    # Add a leading batch dimension of size 1. An explicit float32 dtype gives
    # the same values the legacy torch.Tensor constructor produced.
    batch = torch.tensor(col, dtype=torch.float32).unsqueeze(0)

    # Only the point estimate z is checked. The pz and flag outputs are
    # requested exactly as before but are not inspected here.
    z, _pz, _odds = temps_module.get_pz(
        input_data=batch, return_pz=True, return_flag=True
    )

    zdiff = np.abs(z - ztrue).mean()
    logger.info(f"zdiff: {zdiff}")

    # Assert first, so that "test passed" is only logged when it is true.
    assert zdiff < tolerance, f"zdiff {zdiff} exceeds tolerance {tolerance}"
    logger.info("test passed")


if __name__ == "__main__":
    # Run the check only when the file is executed directly. Under pytest the
    # `test` function is collected and run once, rather than also running
    # (and possibly failing) at import/collection time.
    test()


# %%