Spaces:
Build error
Build error
File size: 1,952 Bytes
3ea26d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import warnings
import pytest
from chemprop import models, nn
from chemprop.models import multi
# Silence every warning raised from a module whose name matches the regex
# "lightning.*" (matched from the start of the module name). append=True places
# this filter at the END of the filter list, so any filters already installed
# (e.g. by pytest or -W options) still take precedence over it.
warnings.filterwarnings(
    action="ignore",
    module=r"lightning.*",
    append=True,
)
@pytest.fixture(scope="session")
def mpnn(request):
    """Build a regression MPNN from an indirectly-parametrized encoder/aggregation pair.

    ``request.param`` is unpacked as ``(message_passing, agg)``. The regression
    FFN's input width is taken from ``message_passing.output_dim``, the same way
    the ``mcmpnn`` fixture sizes its FFN from ``mcmp.output_dim``. The predictor
    therefore matches whatever encoder is parametrized, instead of relying on the
    FFN's default ``input_dim`` happening to equal the encoder's output size.

    The fixture is session-scoped, so tests using the same parameter may receive
    the same model instance.
    """
    message_passing, agg = request.param
    # Assumes ``agg`` preserves the encoder's feature width. TODO: confirm for
    # every aggregation this fixture is parametrized with.
    ffn = nn.RegressionFFN(input_dim=message_passing.output_dim)
    # NOTE(review): the positional ``True`` is presumably ``batch_norm``; confirm
    # against the ``models.MPNN`` signature.
    return models.MPNN(message_passing, agg, ffn, True)
@pytest.fixture(scope="session")
def regression_mpnn_mve(request):
    """Build an MPNN with an ``nn.MveFFN`` head over the parametrized encoder.

    ``request.param`` is the message-passing encoder, supplied via indirect
    parametrization. Readout uses ``nn.SumAggregation``. The FFN input width is
    taken from the encoder's ``output_dim`` rather than the FFN default, matching
    how the ``mcmpnn`` fixture sizes its FFN.
    """
    message_passing = request.param
    agg = nn.SumAggregation()
    ffn = nn.MveFFN(input_dim=message_passing.output_dim)
    # NOTE(review): positional ``True`` is presumably ``batch_norm``; confirm.
    return models.MPNN(message_passing, agg, ffn, True)
@pytest.fixture(scope="session")
def regression_mpnn_evidential(request):
    """Build an MPNN with an ``nn.EvidentialFFN`` head over the parametrized encoder.

    ``request.param`` is the message-passing encoder, supplied via indirect
    parametrization. Readout uses ``nn.SumAggregation``. The FFN input width is
    taken from the encoder's ``output_dim`` rather than the FFN default, matching
    how the ``mcmpnn`` fixture sizes its FFN.
    """
    message_passing = request.param
    agg = nn.SumAggregation()
    ffn = nn.EvidentialFFN(input_dim=message_passing.output_dim)
    # NOTE(review): positional ``True`` is presumably ``batch_norm``; confirm.
    return models.MPNN(message_passing, agg, ffn, True)
@pytest.fixture(scope="session")
def classification_mpnn_dirichlet(request):
    """Build an MPNN with an ``nn.BinaryDirichletFFN`` head over the parametrized encoder.

    ``request.param`` is the message-passing encoder, supplied via indirect
    parametrization. Readout uses ``nn.SumAggregation``. The FFN input width is
    taken from the encoder's ``output_dim`` rather than the FFN default, matching
    how the ``mcmpnn`` fixture sizes its FFN.
    """
    message_passing = request.param
    agg = nn.SumAggregation()
    ffn = nn.BinaryDirichletFFN(input_dim=message_passing.output_dim)
    # NOTE(review): positional ``True`` is presumably ``batch_norm``; confirm.
    return models.MPNN(message_passing, agg, ffn, True)
@pytest.fixture(scope="session")
def classification_mpnn(request):
    """Build an MPNN with an ``nn.BinaryClassificationFFN`` head over the parametrized encoder.

    ``request.param`` is the message-passing encoder, supplied via indirect
    parametrization. Readout uses ``nn.SumAggregation``. The FFN input width is
    taken from the encoder's ``output_dim`` rather than the FFN default, matching
    how the ``mcmpnn`` fixture sizes its FFN.
    """
    message_passing = request.param
    agg = nn.SumAggregation()
    ffn = nn.BinaryClassificationFFN(input_dim=message_passing.output_dim)
    # NOTE(review): positional ``True`` is presumably ``batch_norm``; confirm.
    return models.MPNN(message_passing, agg, ffn, True)
@pytest.fixture(scope="session")
def classification_mpnn_multiclass(request):
    """Build a 3-class MPNN with an ``nn.MulticlassClassificationFFN`` head.

    ``request.param`` is the message-passing encoder, supplied via indirect
    parametrization. Readout uses ``nn.SumAggregation``. The FFN input width is
    taken from the encoder's ``output_dim`` rather than the FFN default, matching
    how the ``mcmpnn`` fixture sizes its FFN.
    """
    message_passing = request.param
    agg = nn.SumAggregation()
    ffn = nn.MulticlassClassificationFFN(
        n_classes=3, input_dim=message_passing.output_dim
    )
    # NOTE(review): positional ``True`` is presumably ``batch_norm``; confirm.
    return models.MPNN(message_passing, agg, ffn, True)
@pytest.fixture(scope="session")
def classification_mpnn_multiclass_dirichlet(request):
    """Build a 3-class MPNN with an ``nn.MulticlassDirichletFFN`` head.

    ``request.param`` is the message-passing encoder, supplied via indirect
    parametrization. Readout uses ``nn.SumAggregation``. The FFN input width is
    taken from the encoder's ``output_dim`` rather than the FFN default, matching
    how the ``mcmpnn`` fixture sizes its FFN.
    """
    message_passing = request.param
    agg = nn.SumAggregation()
    ffn = nn.MulticlassDirichletFFN(
        n_classes=3, input_dim=message_passing.output_dim
    )
    # NOTE(review): positional ``True`` is presumably ``batch_norm``; confirm.
    return models.MPNN(message_passing, agg, ffn, True)
@pytest.fixture(scope="session")
def mcmpnn(request):
    """Build a multicomponent regression MPNN from an indirect parameter.

    ``request.param`` is a ``(blocks, n_components, shared)`` triple that is
    forwarded to ``nn.MulticomponentMessagePassing``. The regression FFN is sized
    from the resulting encoder's ``output_dim``, and readout uses
    ``nn.SumAggregation``.
    """
    blocks, n_components, shared = request.param
    encoder = nn.MulticomponentMessagePassing(blocks, n_components, shared=shared)
    aggregation = nn.SumAggregation()
    predictor = nn.RegressionFFN(input_dim=encoder.output_dim)
    # NOTE(review): positional ``True`` is presumably ``batch_norm``; confirm.
    return multi.MulticomponentMPNN(encoder, aggregation, predictor, True)
|