Huhujingjing
commited on
Commit
•
3e2d1c7
1
Parent(s):
940d605
Upload model
Browse files- modeling_mxm.py +6 -6
modeling_mxm.py
CHANGED
@@ -949,7 +949,7 @@ class MXMModel(PreTrainedModel):
|
|
949 |
def __init__(self, config):
|
950 |
super().__init__(config)
|
951 |
|
952 |
-
self.
|
953 |
dim=config.dim,
|
954 |
n_layer=config.n_layer,
|
955 |
cutoff=config.cutoff,
|
@@ -961,14 +961,14 @@ class MXMModel(PreTrainedModel):
|
|
961 |
smiles=config.smiles,
|
962 |
)
|
963 |
|
964 |
-
self.
|
965 |
self.dataset = None
|
966 |
self.output = None
|
967 |
self.data_loader = None
|
968 |
self.pred_data = None
|
969 |
|
970 |
def forward(self, tensor):
|
971 |
-
return self.
|
972 |
|
973 |
def SmilesProcessor(self, smiles):
|
974 |
return self.process.get_data(smiles)
|
@@ -982,8 +982,8 @@ class MXMModel(PreTrainedModel):
|
|
982 |
drop_last = kwargs.pop('drop_last', False)
|
983 |
num_workers = kwargs.pop('num_workers', 0)
|
984 |
|
985 |
-
self.
|
986 |
-
self.
|
987 |
|
988 |
self.dataset = self.process.get_data(smiles)
|
989 |
self.output = ""
|
@@ -1004,7 +1004,7 @@ class MXMModel(PreTrainedModel):
|
|
1004 |
batch = batch.to(device)
|
1005 |
with torch.no_grad():
|
1006 |
self.pred_data['smiles'] += batch['smiles']
|
1007 |
-
self.pred_data['pred'] += self.
|
1008 |
|
1009 |
pred = torch.tensor(self.pred_data['pred']).reshape(-1)
|
1010 |
if device == 'cuda':
|
|
|
949 |
def __init__(self, config):
|
950 |
super().__init__(config)
|
951 |
|
952 |
+
self.backbone = MXMNet(
|
953 |
dim=config.dim,
|
954 |
n_layer=config.n_layer,
|
955 |
cutoff=config.cutoff,
|
|
|
961 |
smiles=config.smiles,
|
962 |
)
|
963 |
|
964 |
+
self.model = None
|
965 |
self.dataset = None
|
966 |
self.output = None
|
967 |
self.data_loader = None
|
968 |
self.pred_data = None
|
969 |
|
970 |
def forward(self, tensor):
|
971 |
+
return self.backbone.forward_features(tensor)
|
972 |
|
973 |
def SmilesProcessor(self, smiles):
|
974 |
return self.process.get_data(smiles)
|
|
|
982 |
drop_last = kwargs.pop('drop_last', False)
|
983 |
num_workers = kwargs.pop('num_workers', 0)
|
984 |
|
985 |
+
self.model = AutoModel.from_pretrained("Huhujingjing/custom-mxm", trust_remote_code=True).to(device)
|
986 |
+
self.model.eval()
|
987 |
|
988 |
self.dataset = self.process.get_data(smiles)
|
989 |
self.output = ""
|
|
|
1004 |
batch = batch.to(device)
|
1005 |
with torch.no_grad():
|
1006 |
self.pred_data['smiles'] += batch['smiles']
|
1007 |
+
self.pred_data['pred'] += self.model(batch).cpu().tolist()
|
1008 |
|
1009 |
pred = torch.tensor(self.pred_data['pred']).reshape(-1)
|
1010 |
if device == 'cuda':
|