from transformers import PretrainedConfig
from typing import List, Optional


class MXMConfig(PretrainedConfig):
    """Configuration object holding the hyperparameters of the MXM model.

    Each constructor argument is stored as an attribute of the same name;
    any additional keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        dim: The dimension of the input feature.
        n_layer: The number of GCN layers.
        cutoff: The cutoff distance for neighbor searching.
        num_spherical: The number of spherical harmonics.
        num_radial: The number of radial basis.
        envelope_exponent: The envelope exponent.
        smiles: Optional list of SMILES strings to process; ``None`` when
            not supplied.
        processor_class: Name of the processor class; defaults to
            ``"SmilesProcessor"``.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    # Identifier for this configuration type.
    model_type = "mxm"

    def __init__(
        self,
        dim: int = 128,
        n_layer: int = 6,
        cutoff: float = 5.0,
        num_spherical: int = 7,
        num_radial: int = 6,
        envelope_exponent: int = 5,
        smiles: Optional[List[str]] = None,
        processor_class: str = "SmilesProcessor",
        **kwargs,
    ):
        self.dim = dim
        self.n_layer = n_layer
        self.cutoff = cutoff
        self.num_spherical = num_spherical
        self.num_radial = num_radial
        self.envelope_exponent = envelope_exponent
        self.smiles = smiles
        self.processor_class = processor_class
        # Model-specific attributes are set first; every remaining keyword
        # argument is left for the base-class initializer to handle.
        super().__init__(**kwargs)


if __name__ == "__main__":
    # Example: build a config with explicit hyperparameters and write it to
    # the "custom-mxm" directory via save_pretrained.
    mxm_config = MXMConfig(
        dim=128,
        n_layer=6,
        cutoff=5.0,
        num_spherical=7,
        num_radial=6,
        envelope_exponent=5,
        smiles=["C", "CC", "CCC"],
        processor_class="SmilesProcessor",
    )
    mxm_config.save_pretrained("custom-mxm")