# NOTE(review): the lines below were file-viewer page metadata accidentally
# captured with the source; commented out so the module remains importable.
# hbhzm's picture
# Upload 625 files
# 3ea26d1 verified
# raw
# history blame
# 18.2 kB
"""Chemprop unit tests for chemprop/models/loss.py"""
import numpy as np
import pytest
import torch
from chemprop.nn.metrics import (
SID,
BCELoss,
BinaryMCCLoss,
BoundedMSE,
CrossEntropyLoss,
DirichletLoss,
EvidentialLoss,
MulticlassMCCLoss,
MVELoss,
Wasserstein,
)
@pytest.mark.parametrize(
    "preds,targets,mask,weights,task_weights,lt_mask,gt_mask,mse",
    [
        (
            torch.tensor([[-3, 2], [1, -1]], dtype=torch.float),
            torch.zeros([2, 2], dtype=torch.float),
            torch.ones([2, 2], dtype=torch.bool),
            torch.ones([2]),
            torch.ones([2]),
            torch.zeros([2, 2], dtype=torch.bool),
            torch.zeros([2, 2], dtype=torch.bool),
            torch.tensor(3.75000, dtype=torch.float),
        ),
        (
            torch.tensor([[-3, 2], [1, -1]], dtype=torch.float),
            torch.zeros([2, 2], dtype=torch.float),
            torch.ones([2, 2], dtype=torch.bool),
            torch.ones([2]),
            torch.ones([2]),
            torch.zeros([2, 2], dtype=torch.bool),
            torch.ones([2, 2], dtype=torch.bool),
            torch.tensor(2.5000, dtype=torch.float),
        ),
        (
            torch.tensor([[-3, 2], [1, -1]], dtype=torch.float),
            torch.zeros([2, 2], dtype=torch.float),
            torch.ones([2, 2], dtype=torch.bool),
            torch.ones([2]),
            torch.ones([2]),
            torch.ones([2, 2], dtype=torch.bool),
            torch.zeros([2, 2], dtype=torch.bool),
            torch.tensor(1.25000, dtype=torch.float),
        ),
    ],
)
def test_BoundedMSE(preds, targets, mask, weights, task_weights, lt_mask, gt_mask, mse):
    """Verify that the bounded-MSE loss reproduces the expected scalar for
    unbounded, greater-than-bounded, and less-than-bounded target cases."""
    loss_fn = BoundedMSE(task_weights)
    computed = loss_fn(preds, targets, mask, weights, lt_mask, gt_mask)
    torch.testing.assert_close(computed, mse)
@pytest.mark.parametrize(
    "preds,targets,mask,weights,task_weights,lt_mask,gt_mask,likelihood",
    [
        (
            torch.tensor([[0, 1]], dtype=torch.float),
            torch.zeros([1, 1]),
            torch.ones([1, 2], dtype=torch.bool),
            torch.ones([1]),
            torch.ones([2]),
            torch.zeros([2], dtype=torch.bool),
            torch.zeros([2], dtype=torch.bool),
            torch.tensor(0.39894228, dtype=torch.float),
        )
    ],
)
def test_MVE(preds, targets, mask, weights, task_weights, lt_mask, gt_mask, likelihood):
    """
    Test the normal_mve loss function.

    The loss returns a negative log-likelihood, so the likelihood is
    recovered as exp(-nll) and compared against the expected value.
    """
    mve_loss = MVELoss(task_weights)
    nll_calc = mve_loss(preds, targets, mask, weights, lt_mask, gt_mask)
    # Use torch.exp instead of np.exp: np.exp coerces the torch scalar into a
    # numpy array before the tensor-vs-tensor comparison, which is accidental.
    likelihood_calc = torch.exp(-nll_calc)
    torch.testing.assert_close(likelihood_calc, likelihood)
@pytest.mark.parametrize(
    "preds,targets,mask,weights,task_weights,lt_mask,gt_mask,v_kl,expected_loss",
    [
        (
            torch.tensor([[[2, 2]]]),
            torch.ones([1, 1]),
            torch.ones([1, 2], dtype=torch.bool),
            torch.ones([1]),
            torch.ones([1]),
            torch.zeros([1], dtype=torch.bool),
            torch.zeros([1], dtype=torch.bool),
            0,
            torch.tensor(0.6, dtype=torch.float),
        ),
        (
            torch.tensor([[[2, 2]]]),
            torch.ones([1, 1]),
            torch.ones([1, 2], dtype=torch.bool),
            torch.ones([1]),
            torch.ones([1]),
            torch.zeros([1], dtype=torch.bool),
            torch.zeros([1], dtype=torch.bool),
            0.2,
            torch.tensor(0.63862943, dtype=torch.float),
        ),
    ],
)
def test_BinaryDirichlet(
    preds, targets, mask, weights, task_weights, lt_mask, gt_mask, v_kl, expected_loss
):
    """Exercise the Dirichlet loss in the binary-classification setting.

    The expected values were not hand-derived; this is a regression check
    for dimensional consistency.
    """
    loss_fn = DirichletLoss(task_weights=task_weights, v_kl=v_kl)
    computed = loss_fn(preds, targets, mask, weights, lt_mask, gt_mask)
    torch.testing.assert_close(computed, expected_loss)
@pytest.mark.parametrize(
    "preds,targets,mask,weights,task_weights,lt_mask,gt_mask,",
    [
        (
            torch.ones([1, 1]),
            torch.ones([1, 1]),
            torch.ones([1, 2], dtype=torch.bool),
            torch.ones([1]),
            torch.ones([1]),
            torch.zeros([1], dtype=torch.bool),
            torch.zeros([1], dtype=torch.bool),
        )
    ],
)
def test_BinaryDirichlet_wrong_dimensions(
    preds, targets, mask, weights, task_weights, lt_mask, gt_mask
):
    """
    Test that the Dirichlet classification loss raises an IndexError on
    mis-dimensioned inputs.
    """
    # Construct the loss outside the raises-block so that only the loss *call*
    # can satisfy the assertion (mirrors test_Evidential_wrong_dimensions).
    binary_dirichlet_loss = DirichletLoss(task_weights)
    with pytest.raises(IndexError):
        binary_dirichlet_loss(preds, targets, mask, weights, lt_mask, gt_mask)
@pytest.mark.parametrize(
    "preds,targets,mask,weights,task_weights,lt_mask,gt_mask,v_kl,expected_loss",
    [
        (
            torch.tensor([[[0.2, 0.1, 0.3], [0.1, 0.3, 0.1]], [[1.2, 0.5, 1.7], [1.1, 1.4, 0.8]]]),
            torch.tensor([[0, 0], [1, 1]]),
            torch.ones([2, 2], dtype=torch.bool),
            torch.ones([2]),
            torch.ones([2]),
            torch.zeros([2], dtype=torch.bool),
            torch.zeros([2], dtype=torch.bool),
            0.2,
            torch.tensor(1.868991, dtype=torch.float),
        ),
        (
            torch.tensor([[[0.2, 0.1, 0.3], [0.1, 0.3, 0.1]], [[1.2, 0.5, 1.7], [1.1, 1.4, 0.8]]]),
            torch.tensor([[0, 0], [1, 1]]),
            torch.ones([2, 2], dtype=torch.bool),
            torch.ones([2]),
            torch.ones([2]),
            torch.zeros([2], dtype=torch.bool),
            torch.zeros([2], dtype=torch.bool),
            0.0,
            torch.tensor(1.102344, dtype=torch.float),
        ),
    ],
)
def test_MulticlassDirichlet(
    preds, targets, mask, weights, task_weights, lt_mask, gt_mask, v_kl, expected_loss
):
    """Exercise the Dirichlet loss in the multiclass setting.

    The expected values were not hand-derived; this is a regression check
    for dimensional consistency.
    """
    loss_fn = DirichletLoss(task_weights=task_weights, v_kl=v_kl)
    computed = loss_fn(preds, targets, mask, weights, lt_mask, gt_mask)
    torch.testing.assert_close(computed, expected_loss)
@pytest.mark.parametrize(
    "preds,targets,mask,weights,task_weights,lt_mask,gt_mask,v_kl,expected_loss",
    [
        (
            torch.tensor([[2, 2, 2, 2]]),
            torch.ones([1, 1]),
            torch.ones([1, 1], dtype=torch.bool),
            torch.ones([1]),
            torch.ones([1]),
            torch.zeros([1], dtype=torch.bool),
            torch.zeros([1], dtype=torch.bool),
            0,
            torch.tensor(1.56893861, dtype=torch.float),
        ),
        (
            torch.tensor([[2, 2, 2, 2]]),
            torch.ones([1, 1]),
            torch.ones([1, 1], dtype=torch.bool),
            torch.ones([1]),
            torch.ones([1]),
            torch.zeros([1], dtype=torch.bool),
            torch.zeros([1], dtype=torch.bool),
            0.2,
            torch.tensor(2.768938541, dtype=torch.float),
        ),
    ],
)
def test_Evidential(
    preds, targets, mask, weights, task_weights, lt_mask, gt_mask, v_kl, expected_loss
):
    """Exercise the evidential loss, with and without a KL regularization weight.

    The expected values were not hand-derived; this is a regression check
    for dimensional consistency.
    """
    loss_fn = EvidentialLoss(task_weights=task_weights, v_kl=v_kl)
    computed = loss_fn(preds, targets, mask, weights, lt_mask, gt_mask)
    torch.testing.assert_close(computed, expected_loss)
@pytest.mark.parametrize(
    "preds,targets,mask,weights,task_weights,lt_mask,gt_mask",
    [
        (
            torch.ones([2, 2]),
            torch.ones([2, 2]),
            torch.ones([1, 1], dtype=torch.bool),
            torch.ones([1]),
            torch.ones([1]),
            torch.zeros([1], dtype=torch.bool),
            torch.zeros([1], dtype=torch.bool),
        )
    ],
)
def test_Evidential_wrong_dimensions(preds, targets, mask, weights, task_weights, lt_mask, gt_mask):
    """Check that the evidential loss rejects mis-dimensioned inputs with a ValueError."""
    loss_fn = EvidentialLoss(task_weights)
    with pytest.raises(ValueError):
        loss_fn(preds, targets, mask, weights, lt_mask, gt_mask)
@pytest.mark.parametrize(
    "preds,targets,mask,weights,task_weights,lt_mask,gt_mask,expected_loss",
    [
        (
            torch.tensor([2, 2], dtype=torch.float),
            torch.ones([2], dtype=torch.float),
            torch.ones([2], dtype=torch.bool),
            torch.ones([1]),
            torch.ones([2]),
            torch.zeros([2], dtype=torch.bool),
            torch.zeros([2], dtype=torch.bool),
            torch.tensor(0.126928, dtype=torch.float),
        ),
        (
            torch.tensor([0.5, 0.5], dtype=torch.float),
            torch.ones([2], dtype=torch.float),
            torch.ones([2], dtype=torch.bool),
            torch.ones([1]),
            torch.ones([2]),
            torch.zeros([2], dtype=torch.bool),
            torch.zeros([2], dtype=torch.bool),
            torch.tensor(0.474077, dtype=torch.float),
        ),
    ],
)
def test_BCE(preds, targets, mask, weights, task_weights, lt_mask, gt_mask, expected_loss):
    """Check the binary cross-entropy loss against known expected values."""
    computed = BCELoss(task_weights)(preds, targets, mask, weights, lt_mask, gt_mask)
    torch.testing.assert_close(computed, expected_loss)
@pytest.mark.parametrize(
    "preds,targets,mask,weights,task_weights,lt_mask,gt_mask,expected_loss",
    [
        (
            torch.tensor([[[1.2, 0.5, 0.7], [-0.1, 0.3, 0.1]], [[1.2, 0.5, 0.7], [1.1, 1.3, 1.1]]]),
            torch.tensor([[1, 0], [1, 2]]),
            torch.ones([2, 2], dtype=torch.bool),
            torch.ones([2]),
            torch.ones([2]),
            torch.ones([2, 2], dtype=torch.bool),
            torch.ones([2, 2], dtype=torch.bool),
            torch.tensor(1.34214, dtype=torch.float),
        ),
        (
            torch.tensor([[[1.2, 1.5, 0.7], [-0.1, 2.3, 1.1]], [[1.2, 1.5, 1.7], [2.1, 1.3, 1.1]]]),
            torch.tensor([[1, 1], [2, 2]], dtype=torch.float64),
            torch.ones([2, 2], dtype=torch.bool),
            torch.ones([2]),
            torch.ones([2]),
            torch.ones([2, 2], dtype=torch.bool),
            torch.ones([2, 2], dtype=torch.bool),
            torch.tensor(0.899472, dtype=torch.float),
        ),
    ],
)
def test_CrossEntropy(preds, targets, mask, weights, task_weights, lt_mask, gt_mask, expected_loss):
    """Exercise the cross-entropy loss for multiclass classification.

    The expected values were not hand-derived; this is a regression check
    for dimensional consistency.
    """
    computed = CrossEntropyLoss(task_weights)(preds, targets, mask, weights, lt_mask, gt_mask)
    torch.testing.assert_close(computed, expected_loss)
@pytest.mark.parametrize(
    "preds,targets,mask,weights,task_weights,lt_mask,gt_mask,expected_loss",
    [
        (
            torch.tensor([0, 1, 1, 0]),
            torch.tensor([0, 1, 1, 0]),
            torch.ones([4], dtype=torch.bool),
            torch.ones(1),
            torch.ones(4),
            torch.zeros([1, 4], dtype=torch.bool),
            torch.zeros([1, 4], dtype=torch.bool),
            torch.tensor(0, dtype=torch.float),
        ),
        (
            torch.tensor([0, 1, 0, 1, 1, 1, 0, 1, 1]),
            torch.tensor([0, 1, 1, 0, 1, 1, 0, 0, 1]),
            torch.ones([9], dtype=torch.bool),
            torch.ones(1),
            torch.ones(9),
            torch.zeros([1, 9], dtype=torch.bool),
            torch.zeros([1, 9], dtype=torch.bool),
            torch.tensor(0.683772, dtype=torch.float),
        ),
    ],
)
def test_BinaryMCC(preds, targets, mask, weights, task_weights, lt_mask, gt_mask, expected_loss):
    """Check the binary MCC loss; expected values were verified with TorchMetrics."""
    loss_fn = BinaryMCCLoss(task_weights)
    computed = loss_fn(preds, targets, mask, weights, lt_mask, gt_mask)
    torch.testing.assert_close(computed, expected_loss)
@pytest.mark.parametrize(
    "preds,targets,mask,weights,task_weights,lt_mask,gt_mask,expected_loss",
    [
        (
            torch.tensor(
                [[[0.16, 0.26, 0.58], [0.22, 0.61, 0.17]], [[0.71, 0.09, 0.20], [0.05, 0.82, 0.13]]]
            ),
            torch.tensor([[2, 1], [0, 0]]),
            torch.ones([2, 2], dtype=torch.bool),
            torch.ones([2]),
            torch.ones([2]),
            torch.zeros([2, 2], dtype=torch.bool),
            torch.zeros([2, 2], dtype=torch.bool),
            torch.tensor(0.5, dtype=torch.float),
        ),
        (
            torch.tensor(
                [[[0.16, 0.26, 0.58], [0.22, 0.61, 0.17]], [[0.71, 0.09, 0.20], [0.05, 0.82, 0.13]]]
            ),
            torch.tensor([[2, 1], [0, 0]]),
            torch.tensor([[1, 1], [0, 1]], dtype=torch.bool),
            torch.ones([2]),
            torch.ones([2]),
            torch.zeros([2, 2], dtype=torch.bool),
            torch.zeros([2, 2], dtype=torch.bool),
            torch.tensor(1.0, dtype=torch.float),
        ),
    ],
)
def test_MulticlassMCC(
    preds, targets, mask, weights, task_weights, lt_mask, gt_mask, expected_loss
):
    """Check the multiclass MCC loss, with and without masked-out entries."""
    loss_fn = MulticlassMCCLoss(task_weights)
    computed = loss_fn(preds, targets, mask, weights, lt_mask, gt_mask)
    torch.testing.assert_close(computed, expected_loss)
@pytest.mark.parametrize(
    "preds,targets,mask,weights,task_weights,lt_mask,gt_mask,threshold,expected_loss",
    [
        (
            torch.tensor([[0.8, 0.2], [0.3, 0.7]]),
            torch.tensor([[0.9, 0.1], [0.4, 0.6]]),
            torch.ones([2, 2], dtype=torch.bool),
            torch.ones([1]),
            torch.ones([2]),
            torch.ones([2], dtype=torch.bool),
            torch.ones([2], dtype=torch.bool),
            None,
            torch.tensor(0.031319, dtype=torch.float),
        ),
        (
            torch.tensor([[0.6, 0.4], [0.2, 0.8]]),
            torch.tensor([[0.7, 0.3], [0.3, 0.7]]),
            torch.tensor([[1, 1], [1, 0]], dtype=torch.bool),
            torch.ones([1]),
            torch.ones([2]),
            torch.ones([2], dtype=torch.bool),
            torch.ones([2], dtype=torch.bool),
            None,
            torch.tensor(0.295655, dtype=torch.float),
        ),
        (
            torch.tensor([[0.6, 0.4], [0.2, 0.8]]),
            torch.tensor([[0.7, 0.3], [0.3, 0.7]]),
            torch.tensor([[1, 1], [1, 1]], dtype=torch.bool),
            torch.ones([1]),
            torch.ones([2]),
            torch.ones([2], dtype=torch.bool),
            torch.ones([2], dtype=torch.bool),
            0.5,
            torch.tensor(0.033673, dtype=torch.float),
        ),
    ],
)
def test_SID(
    preds, targets, mask, weights, task_weights, lt_mask, gt_mask, threshold, expected_loss
):
    """Exercise the SID loss with/without a mask and a threshold.

    The expected values were not hand-checked; this is a regression check
    that the function returns stable values for these configurations.
    """
    loss_fn = SID(task_weights=task_weights, threshold=threshold)
    computed = loss_fn(preds, targets, mask, weights, lt_mask, gt_mask)
    torch.testing.assert_close(computed, expected_loss)
@pytest.mark.parametrize(
    "preds,targets,mask,weights,task_weights,lt_mask,gt_mask,threshold,expected_loss",
    [
        (
            torch.tensor([[0.1, 0.3, 0.5, 0.7], [0.2, 0.4, 0.6, 0.8]]),
            torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]),
            torch.tensor([[1, 1, 1, 1], [1, 0, 1, 0]], dtype=torch.bool),
            torch.ones([2, 1]),
            torch.ones([1, 4]),
            torch.zeros([2, 4], dtype=torch.bool),
            torch.zeros([2, 4], dtype=torch.bool),
            None,
            torch.tensor(0.1125, dtype=torch.float),
        ),
        (
            torch.tensor([[0.1, 0.3, 0.5, 0.7], [0.2, 0.4, 0.6, 0.8]]),
            torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]),
            torch.ones([2, 4], dtype=torch.bool),
            torch.ones([2, 1]),
            torch.ones([1, 4]),
            torch.zeros([2, 4], dtype=torch.bool),
            torch.zeros([2, 4], dtype=torch.bool),
            None,
            torch.tensor(0.515625, dtype=torch.float),
        ),
        (
            torch.tensor([[0.1, 0.3, 0.5, 0.7], [0.2, 0.4, 0.6, 0.8]]),
            torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]),
            torch.ones([2, 4], dtype=torch.bool),
            torch.ones([2, 1]),
            torch.ones([1, 4]),
            torch.zeros([2, 4], dtype=torch.bool),
            torch.zeros([2, 4], dtype=torch.bool),
            0.3,
            torch.tensor(0.501984, dtype=torch.float),
        ),
    ],
)
def test_Wasserstein(
    preds, targets, mask, weights, task_weights, lt_mask, gt_mask, threshold, expected_loss
):
    """Exercise the Wasserstein loss with/without a mask and a threshold.

    The expected values were not hand-checked; this is a regression check
    that the function returns stable values for these configurations.
    """
    loss_fn = Wasserstein(task_weights=task_weights, threshold=threshold)
    computed = loss_fn(preds, targets, mask, weights, lt_mask, gt_mask)
    torch.testing.assert_close(computed, expected_loss)
# TODO: Add quantile loss tests