Huhujingjing commited on
Commit
3e2d1c7
1 Parent(s): 940d605

Upload model

Browse files
Files changed (1) hide show
  1. modeling_mxm.py +6 -6
modeling_mxm.py CHANGED
@@ -949,7 +949,7 @@ class MXMModel(PreTrainedModel):
949
  def __init__(self, config):
950
  super().__init__(config)
951
 
952
- self.model = MXMNet(
953
  dim=config.dim,
954
  n_layer=config.n_layer,
955
  cutoff=config.cutoff,
@@ -961,14 +961,14 @@ class MXMModel(PreTrainedModel):
961
  smiles=config.smiles,
962
  )
963
 
964
- self.mxm_model = None
965
  self.dataset = None
966
  self.output = None
967
  self.data_loader = None
968
  self.pred_data = None
969
 
970
  def forward(self, tensor):
971
- return self.model.forward_features(tensor)
972
 
973
  def SmilesProcessor(self, smiles):
974
  return self.process.get_data(smiles)
@@ -982,8 +982,8 @@ class MXMModel(PreTrainedModel):
982
  drop_last = kwargs.pop('drop_last', False)
983
  num_workers = kwargs.pop('num_workers', 0)
984
 
985
- self.mxm_model = AutoModel.from_pretrained("Huhujingjing/custom-mxm", trust_remote_code=True).to(device)
986
- self.mxm_model.eval()
987
 
988
  self.dataset = self.process.get_data(smiles)
989
  self.output = ""
@@ -1004,7 +1004,7 @@ class MXMModel(PreTrainedModel):
1004
  batch = batch.to(device)
1005
  with torch.no_grad():
1006
  self.pred_data['smiles'] += batch['smiles']
1007
- self.pred_data['pred'] += self.gcn_model(batch).cpu().tolist()
1008
 
1009
  pred = torch.tensor(self.pred_data['pred']).reshape(-1)
1010
  if device == 'cuda':
 
949
  def __init__(self, config):
950
  super().__init__(config)
951
 
952
+ self.backbone = MXMNet(
953
  dim=config.dim,
954
  n_layer=config.n_layer,
955
  cutoff=config.cutoff,
 
961
  smiles=config.smiles,
962
  )
963
 
964
+ self.model = None
965
  self.dataset = None
966
  self.output = None
967
  self.data_loader = None
968
  self.pred_data = None
969
 
970
  def forward(self, tensor):
971
+ return self.backbone.forward_features(tensor)
972
 
973
  def SmilesProcessor(self, smiles):
974
  return self.process.get_data(smiles)
 
982
  drop_last = kwargs.pop('drop_last', False)
983
  num_workers = kwargs.pop('num_workers', 0)
984
 
985
+ self.model = AutoModel.from_pretrained("Huhujingjing/custom-mxm", trust_remote_code=True).to(device)
986
+ self.model.eval()
987
 
988
  self.dataset = self.process.get_data(smiles)
989
  self.output = ""
 
1004
  batch = batch.to(device)
1005
  with torch.no_grad():
1006
  self.pred_data['smiles'] += batch['smiles']
1007
+ self.pred_data['pred'] += self.model(batch).cpu().tolist()
1008
 
1009
  pred = torch.tensor(self.pred_data['pred']).reshape(-1)
1010
  if device == 'cuda':