File size: 1,451 Bytes
940d605 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
from typing import List, Optional

from transformers import PretrainedConfig
class MXMConfig(PretrainedConfig):
    """Configuration object holding the hyperparameters of an "mxm" model.

    Every constructor argument is stored unchanged as an instance attribute
    of the same name. Any additional keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        dim: The dimension of the input feature.
        n_layer: The number of GCN layers.
        cutoff: The cutoff distance for neighbor searching.
        num_spherical: The number of spherical harmonics.
        num_radial: The number of radial basis functions.
        envelope_exponent: The envelope exponent.
        smiles: SMILES strings to process, or ``None`` when none are given.
        processor_class: Stored as given; presumably the name of the
            processor class paired with this model (confirm against the
            processor implementation).
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "mxm"

    def __init__(
        self,
        dim: int = 128,
        n_layer: int = 6,
        cutoff: float = 5.0,
        num_spherical: int = 7,
        num_radial: int = 6,
        envelope_exponent: int = 5,
        # ``None`` is the default, so the annotation must say so explicitly
        # (implicit Optional is rejected by current type checkers).
        smiles: Optional[List[str]] = None,
        processor_class: str = "SmilesProcessor",
        **kwargs,
    ) -> None:
        self.dim = dim
        self.n_layer = n_layer
        self.cutoff = cutoff
        self.num_spherical = num_spherical
        self.num_radial = num_radial
        self.envelope_exponent = envelope_exponent
        self.smiles = smiles
        self.processor_class = processor_class
        # The base initializer runs last, after the model-specific
        # attributes above have been set (same order as before).
        super().__init__(**kwargs)
if __name__ == "__main__":
    # Example run: spell out every setting explicitly, build the config from
    # them, and save it under "custom-mxm" via ``save_pretrained``.
    example_settings = {
        "dim": 128,
        "n_layer": 6,
        "cutoff": 5.0,
        "num_spherical": 7,
        "num_radial": 6,
        "envelope_exponent": 5,
        "smiles": ["C", "CC", "CCC"],
        "processor_class": "SmilesProcessor",
    }
    mxm_config = MXMConfig(**example_settings)
    mxm_config.save_pretrained("custom-mxm")