custom-mxm / configuration_mxm.py
Huhujingjing's picture
Upload model
940d605
raw
history blame
1.45 kB
from typing import List, Optional

from transformers import PretrainedConfig
class MXMConfig(PretrainedConfig):
    """Configuration object for the MXM model (``model_type = "mxm"``).

    Holds the model hyperparameters, an optional list of SMILES strings and
    the name of a processor class. Any additional keyword arguments are
    forwarded unchanged to ``PretrainedConfig.__init__``.

    Args:
        dim: The dimension of the input feature.
        n_layer: The number of GCN layers.
        cutoff: The cutoff distance used for neighbor searching.
        num_spherical: The number of spherical harmonics.
        num_radial: The number of radial basis functions.
        envelope_exponent: The envelope exponent.
        smiles: Optional list of SMILES strings to process; ``None`` when not
            supplied.
        processor_class: Name of the processor class stored on the config
            (presumably resolved by name elsewhere — TODO confirm).
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "mxm"

    def __init__(
        self,
        dim: int = 128,
        n_layer: int = 6,
        cutoff: float = 5.0,
        num_spherical: int = 7,
        num_radial: int = 6,
        envelope_exponent: int = 5,
        smiles: Optional[List[str]] = None,
        processor_class: str = "SmilesProcessor",
        **kwargs,
    ):
        self.dim = dim  # the dimension of input feature
        self.n_layer = n_layer  # the number of GCN layers
        self.cutoff = cutoff  # the cutoff distance for neighbor searching
        self.num_spherical = num_spherical  # the number of spherical harmonics
        self.num_radial = num_radial  # the number of radial basis
        self.envelope_exponent = envelope_exponent  # the envelope exponent
        self.smiles = smiles  # SMILES strings to process (may be None)
        self.processor_class = processor_class
        # Model-specific attributes are assigned first; everything else the
        # caller passed is handled by the base-class initializer.
        super().__init__(**kwargs)
if __name__ == "__main__":
    # Demo: build an MXMConfig from explicit hyperparameters plus a few
    # example SMILES strings, then persist it with save_pretrained into the
    # "custom-mxm" directory.
    demo_kwargs = {
        "dim": 128,
        "n_layer": 6,
        "cutoff": 5.0,
        "num_spherical": 7,
        "num_radial": 6,
        "envelope_exponent": 5,
        "smiles": ["C", "CC", "CCC"],
        "processor_class": "SmilesProcessor",
    }
    demo_config = MXMConfig(**demo_kwargs)
    demo_config.save_pretrained("custom-mxm")