import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

#################################
# Latent Space Distance Metrics #
#################################
# Every metric scores row i of ``x1`` against row i of ``x2`` (with standard
# broadcasting), so ``ConPLex_DTI.classify`` gets one score per drug/target
# pair rather than an all-pairs matrix.


class Cosine(nn.Module):
    """Cosine similarity between paired latent vectors, computed along dim 1."""

    def forward(self, x1, x2):
        # Same defaults as nn.CosineSimilarity() (dim=1, eps=1e-8) without
        # constructing a new module on every call.
        return nn.functional.cosine_similarity(x1, x2, dim=1)


class SquaredCosine(nn.Module):
    """Square of the paired cosine similarity, computed along dim 1."""

    def forward(self, x1, x2):
        return nn.functional.cosine_similarity(x1, x2, dim=1) ** 2


class Euclidean(nn.Module):
    """Euclidean (L2) distance between paired latent vectors (last axis)."""

    def forward(self, x1, x2):
        # torch.cdist would compare every row of x1 with every row of x2 and
        # return a (batch, batch) matrix; the norm of the difference keeps the
        # per-pair semantics that Cosine has.
        return torch.linalg.vector_norm(x1 - x2, ord=2, dim=-1)


class SquaredEuclidean(nn.Module):
    """Squared Euclidean distance between paired latent vectors (last axis)."""

    def forward(self, x1, x2):
        # Sum of squares directly: avoids taking a sqrt only to square it.
        return (x1 - x2).pow(2).sum(dim=-1)


DISTANCE_METRICS = {
    "Cosine": Cosine,
    "SquaredCosine": SquaredCosine,
    "Euclidean": Euclidean,
    "SquaredEuclidean": SquaredEuclidean,
}

ACTIVATIONS = {"ReLU": nn.ReLU, "GELU": nn.GELU, "ELU": nn.ELU, "Sigmoid": nn.Sigmoid}


def _lookup(table, key, kind):
    """Return ``table[key]``; on a miss raise KeyError listing the valid keys."""
    try:
        return table[key]
    except KeyError:
        raise KeyError(
            f"Unknown {kind} {key!r}; expected one of {sorted(table)}"
        ) from None


class ConPLex_DTI(nn.Module, PyTorchModelHubMixin):
    """Project drug and target features into a shared latent space and score pairs.

    Each input passes through ``nn.Linear`` followed by the chosen activation.
    With ``classify=True`` the two projections are compared with the chosen
    entry of ``DISTANCE_METRICS``. Otherwise ``regress`` returns their dot
    product.

    Args:
        drug_shape: Feature size of the drug input (last axis).
        target_shape: Feature size of the target input (last axis).
        latent_dimension: Size of the shared latent space.
        latent_activation: Key into ``ACTIVATIONS``.
        latent_distance: Key into ``DISTANCE_METRICS``. It is only looked up
            when ``classify`` is True.
        classify: When True, ``forward`` calls ``classify``; otherwise it
            calls ``regress``.

    Raises:
        KeyError: If ``latent_activation`` is unknown, or if
            ``latent_distance`` is unknown while ``classify`` is True.
    """

    def __init__(
        self,
        drug_shape=2048,
        target_shape=1024,
        latent_dimension=1024,
        latent_activation="ReLU",
        latent_distance="Cosine",
        classify=True,
    ):
        super().__init__()
        self.drug_shape = drug_shape
        self.target_shape = target_shape
        self.latent_dimension = latent_dimension
        self.do_classify = classify
        # Stored as the activation *class*; instantiated once per projector.
        self.latent_activation = _lookup(ACTIVATIONS, latent_activation, "latent_activation")

        self.drug_projector = nn.Sequential(
            nn.Linear(self.drug_shape, latent_dimension), self.latent_activation()
        )
        nn.init.xavier_normal_(self.drug_projector[0].weight)

        self.target_projector = nn.Sequential(
            nn.Linear(self.target_shape, latent_dimension), self.latent_activation()
        )
        nn.init.xavier_normal_(self.target_projector[0].weight)

        if self.do_classify:
            self.distance_metric = latent_distance
            self.activator = _lookup(DISTANCE_METRICS, self.distance_metric, "latent_distance")()

    def forward(self, drug, target):
        """Dispatch to ``classify`` or ``regress`` according to ``do_classify``."""
        if self.do_classify:
            return self.classify(drug, target)
        else:
            return self.regress(drug, target)

    def regress(self, drug, target):
        """Return the dot product of each drug/target latent projection pair.

        Both projections are flattened to rows of ``latent_dimension`` before a
        batched matrix product. The result is a 1-D tensor, or a 0-d tensor
        when there is exactly one pair.
        """
        drug_projection = self.drug_projector(drug)
        target_projection = self.target_projector(target)

        # (N, 1, L) @ (N, L, 1) -> (N, 1, 1); reshape (unlike view) also
        # accepts non-contiguous tensors and returns a view when it can.
        inner_prod = torch.bmm(
            drug_projection.reshape(-1, 1, self.latent_dimension),
            target_projection.reshape(-1, self.latent_dimension, 1),
        )
        # A single squeeze already removes every size-1 axis.
        return inner_prod.squeeze()

    def classify(self, drug, target):
        """Score each drug/target pair with the configured distance metric."""
        drug_projection = self.drug_projector(drug)
        target_projection = self.target_projector(target)

        distance = self.activator(drug_projection, target_projection)
        return distance.squeeze()


if __name__ == "__main__":
    # Convert a local checkpoint into a Hub model, upload it, then reload it
    # from the Hub.
    model_path = "./models/conplex_v1_bindingdb.pt"
    model = ConPLex_DTI()
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (consider weights_only=True on torch versions that support it).
    model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
    model.save_pretrained("ConPLex_V1_BindingDB")
    model.push_to_hub("ConPLex_V1_BindingDB")

    model = ConPLex_DTI.from_pretrained("samsl/ConPLex_V1_BindingDB")