File size: 1,952 Bytes
3ea26d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import warnings

import pytest

from chemprop import models, nn
from chemprop.models import multi

# Ignore warnings attributed to modules whose name matches ``lightning.*``.
# ``append=True`` adds the filter at the end of the list (lowest priority),
# so filters installed earlier (e.g. via pytest/-W options) still win.
warnings.filterwarnings(action="ignore", module=r"lightning.*", append=True)


@pytest.fixture(scope="session")
def mpnn(request):
    """Session-scoped MPNN built from a parametrized (message passing, aggregation) pair.

    ``request.param`` is supplied through indirect parametrization and must
    unpack into exactly two items: the message-passing module and the
    aggregation module. The predictor is a default-constructed
    ``nn.RegressionFFN()``. Because of the session scope, pytest reuses one
    instance per distinct param for the whole test session.
    """
    mp_block, aggregation = request.param
    # NOTE(review): the trailing positional ``True`` is MPNN's 4th argument,
    # presumably ``batch_norm`` -- confirm against the MPNN signature.
    return models.MPNN(mp_block, aggregation, nn.RegressionFFN(), True)


@pytest.fixture(scope="session")
def regression_mpnn_mve(request):
    """Session-scoped MPNN with ``nn.SumAggregation`` and an ``nn.MveFFN`` predictor.

    ``request.param`` (set via indirect parametrization) is passed through
    unchanged as the message-passing module.
    """
    # Tuple assignment evaluates left to right: aggregation first, then predictor.
    aggregation, predictor = nn.SumAggregation(), nn.MveFFN()
    # NOTE(review): trailing positional ``True`` is presumably ``batch_norm`` -- confirm.
    return models.MPNN(request.param, aggregation, predictor, True)


@pytest.fixture(scope="session")
def regression_mpnn_evidential(request):
    """Session-scoped MPNN with ``nn.SumAggregation`` and an ``nn.EvidentialFFN`` predictor.

    ``request.param`` (set via indirect parametrization) is passed through
    unchanged as the message-passing module.
    """
    # Tuple assignment evaluates left to right: aggregation first, then predictor.
    aggregation, predictor = nn.SumAggregation(), nn.EvidentialFFN()
    # NOTE(review): trailing positional ``True`` is presumably ``batch_norm`` -- confirm.
    return models.MPNN(request.param, aggregation, predictor, True)


@pytest.fixture(scope="session")
def classification_mpnn_dirichlet(request):
    """Session-scoped MPNN with ``nn.SumAggregation`` and an ``nn.BinaryDirichletFFN`` predictor.

    ``request.param`` (set via indirect parametrization) is passed through
    unchanged as the message-passing module.
    """
    # Tuple assignment evaluates left to right: aggregation first, then predictor.
    aggregation, predictor = nn.SumAggregation(), nn.BinaryDirichletFFN()
    # NOTE(review): trailing positional ``True`` is presumably ``batch_norm`` -- confirm.
    return models.MPNN(request.param, aggregation, predictor, True)


@pytest.fixture(scope="session")
def classification_mpnn(request):
    """Session-scoped MPNN with ``nn.SumAggregation`` and an ``nn.BinaryClassificationFFN`` predictor.

    ``request.param`` (set via indirect parametrization) is passed through
    unchanged as the message-passing module.
    """
    # Tuple assignment evaluates left to right: aggregation first, then predictor.
    aggregation, predictor = nn.SumAggregation(), nn.BinaryClassificationFFN()
    # NOTE(review): trailing positional ``True`` is presumably ``batch_norm`` -- confirm.
    return models.MPNN(request.param, aggregation, predictor, True)


@pytest.fixture(scope="session")
def classification_mpnn_multiclass(request):
    """Session-scoped MPNN with ``nn.SumAggregation`` and a 3-class ``nn.MulticlassClassificationFFN``.

    ``request.param`` (set via indirect parametrization) is passed through
    unchanged as the message-passing module.
    """
    # Tuple assignment evaluates left to right: aggregation first, then predictor.
    aggregation, predictor = nn.SumAggregation(), nn.MulticlassClassificationFFN(n_classes=3)
    # NOTE(review): trailing positional ``True`` is presumably ``batch_norm`` -- confirm.
    return models.MPNN(request.param, aggregation, predictor, True)


@pytest.fixture(scope="session")
def classification_mpnn_multiclass_dirichlet(request):
    """Session-scoped MPNN with ``nn.SumAggregation`` and a 3-class ``nn.MulticlassDirichletFFN``.

    ``request.param`` (set via indirect parametrization) is passed through
    unchanged as the message-passing module.
    """
    # Tuple assignment evaluates left to right: aggregation first, then predictor.
    aggregation, predictor = nn.SumAggregation(), nn.MulticlassDirichletFFN(n_classes=3)
    # NOTE(review): trailing positional ``True`` is presumably ``batch_norm`` -- confirm.
    return models.MPNN(request.param, aggregation, predictor, True)


@pytest.fixture(scope="session")
def mcmpnn(request):
    """Session-scoped multicomponent MPNN built from a parametrized triple.

    ``request.param`` (set via indirect parametrization) must unpack into
    ``(blocks, n_components, shared)``, which configure an
    ``nn.MulticomponentMessagePassing``. The regression predictor's
    ``input_dim`` is taken from that module's ``output_dim`` so the two
    widths agree.
    """
    mp_blocks, n_mols, share_weights = request.param
    message_passing = nn.MulticomponentMessagePassing(mp_blocks, n_mols, shared=share_weights)
    aggregation = nn.SumAggregation()
    predictor = nn.RegressionFFN(input_dim=message_passing.output_dim)
    # NOTE(review): trailing positional ``True`` is presumably ``batch_norm`` -- confirm.
    return multi.MulticomponentMPNN(message_passing, aggregation, predictor, True)