Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from huggingface_hub import PyTorchModelHubMixin | |
################################# | |
# Latent Space Distance Metrics # | |
################################# | |
class Cosine(nn.Module):
    """Cosine similarity between paired latent vectors.

    Compares ``x1`` and ``x2`` along dimension 1 (``eps=1e-8``), so for
    ``(B, D)`` inputs the result has shape ``(B,)``.
    """

    def forward(self, x1, x2):
        # Same computation as nn.CosineSimilarity() with its defaults, without
        # constructing a throwaway module on every call.
        similarity = nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-8)
        return similarity
class SquaredCosine(nn.Module):
    """Square of the cosine similarity between paired latent vectors.

    The similarity is taken along dimension 1 (``eps=1e-8``) and then squared,
    so opposite directions score the same as identical ones.
    """

    def forward(self, x1, x2):
        cos = nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-8)
        return cos.pow(2)
class Euclidean(nn.Module):
    """Row-wise Euclidean (L2) distance between paired latent vectors.

    Computes ``||x1[..., :] - x2[..., :]||_2`` over the last dimension, so
    ``(B, D)`` inputs yield a ``(B,)`` tensor: one score per drug/target pair.
    This matches the per-pair output shape of the cosine metrics.
    """

    def forward(self, x1, x2):
        # torch.cdist would return the full (B, B) all-pairs distance matrix for
        # 2-D inputs, which does not line up with one score per pair.
        return torch.linalg.vector_norm(x1 - x2, ord=2, dim=-1)
class SquaredEuclidean(nn.Module):
    """Row-wise squared Euclidean distance between paired latent vectors.

    Sums the squared element differences over the last dimension, so
    ``(B, D)`` inputs yield a ``(B,)`` tensor: one score per drug/target pair.
    """

    def forward(self, x1, x2):
        # torch.cdist(...) ** 2 would produce the (B, B) all-pairs matrix instead
        # of one value per pair; summing squared differences also avoids a sqrt.
        diff = x1 - x2
        return (diff * diff).sum(dim=-1)
# Registry of latent-space scoring modules, keyed by the name accepted as
# ConPLex_DTI(latent_distance=...). Values are classes; the model instantiates
# the selected one when built in classification mode.
DISTANCE_METRICS = dict(
    Cosine=Cosine,
    SquaredCosine=SquaredCosine,
    Euclidean=Euclidean,
    SquaredEuclidean=SquaredEuclidean,
)
# Activations accepted as ConPLex_DTI(latent_activation=...). Values are the
# nn.Module classes themselves (not instances); the model creates a new
# instance for each projector.
ACTIVATIONS = dict(
    ReLU=nn.ReLU,
    GELU=nn.GELU,
    ELU=nn.ELU,
    Sigmoid=nn.Sigmoid,
)
class ConPLex_DTI(nn.Module, PyTorchModelHubMixin):
    """Project drug and target features into a shared latent space and score pairs.

    Each input goes through one ``nn.Linear`` layer (Xavier-normal weights)
    followed by the configured activation. In classification mode the two
    projections are scored by a module from ``DISTANCE_METRICS``; otherwise
    their inner product is returned.

    Args:
        drug_shape: Feature size of the drug input.
        target_shape: Feature size of the target input.
        latent_dimension: Size of the shared latent space.
        latent_activation: Key into ``ACTIVATIONS``.
        latent_distance: Key into ``DISTANCE_METRICS``. It is only looked up
            when ``classify`` is true and is otherwise ignored.
        classify: If true, ``forward`` calls ``classify``; else ``regress``.

    Raises:
        KeyError: If ``latent_activation`` is not in ``ACTIVATIONS``, or if
            ``classify`` is true and ``latent_distance`` is not in
            ``DISTANCE_METRICS``.
    """

    def __init__(
        self,
        drug_shape=2048,
        target_shape=1024,
        latent_dimension=1024,
        latent_activation="ReLU",
        latent_distance="Cosine",
        classify=True,
    ):
        super().__init__()
        self.drug_shape = drug_shape
        self.target_shape = target_shape
        self.latent_dimension = latent_dimension
        self.do_classify = classify
        # Stores the activation *class*; each projector gets its own instance.
        self.latent_activation = ACTIVATIONS[latent_activation]

        # Drug projector first, then target. Each is initialised immediately
        # after construction, which keeps the random-number draw order (and
        # therefore seeded initialisation) the same as before.
        self.drug_projector = self._make_projector(drug_shape)
        self.target_projector = self._make_projector(target_shape)

        if classify:
            self.distance_metric = latent_distance
            self.activator = DISTANCE_METRICS[latent_distance]()

    def _make_projector(self, in_features):
        """Build ``Sequential(Linear(in_features, latent), activation)`` with Xavier-normal weights."""
        linear = nn.Linear(in_features, self.latent_dimension)
        nn.init.xavier_normal_(linear.weight)
        return nn.Sequential(linear, self.latent_activation())

    def forward(self, drug, target):
        """Score drug/target pairs with ``classify`` or ``regress`` per ``do_classify``."""
        scorer = self.classify if self.do_classify else self.regress
        return scorer(drug, target)

    def regress(self, drug, target):
        """Return the inner product of each drug/target latent pair.

        The projections are viewed as ``(-1, 1, latent)`` and ``(-1, latent, 1)``
        and multiplied with ``torch.bmm``; every singleton dimension is then
        squeezed away. NOTE(review): a batch of one therefore comes back as a
        0-d tensor.
        """
        row = self.drug_projector(drug).view(-1, 1, self.latent_dimension)
        col = self.target_projector(target).view(-1, self.latent_dimension, 1)
        # squeeze() is idempotent, so one call is enough.
        return torch.bmm(row, col).squeeze()

    def classify(self, drug, target):
        """Score each pair with ``self.activator`` applied to the latent projections.

        ``self.activator`` only exists when the model was built with
        ``classify=True``; calling this otherwise raises ``AttributeError``.
        """
        scores = self.activator(
            self.drug_projector(drug),
            self.target_projector(target),
        )
        return scores.squeeze()
if __name__ == "__main__":
    # One-off conversion: load a local checkpoint into ConPLex_DTI, save it in
    # Hub format, upload it, then reload the published copy by repo id.
    model_path = "./models/conplex_v1_bindingdb.pt"
    model = ConPLex_DTI()
    # The file is only passed to load_state_dict, so it needs nothing but
    # tensors; weights_only=True refuses to unpickle arbitrary objects.
    state_dict = torch.load(
        model_path, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.save_pretrained("ConPLex_V1_BindingDB")
    # NOTE(review): a repo id without a namespace presumably resolves to the
    # logged-in account, while the reload below reads "samsl/..." -- confirm
    # these are meant to be the same repository.
    model.push_to_hub("ConPLex_V1_BindingDB")
    model = ConPLex_DTI.from_pretrained("samsl/ConPLex_V1_BindingDB")