# Header captured from a Hugging Face Spaces file view (Space status: Sleeping).
# File size: 3,093 bytes; revision dfffe94.
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
#################################
# Latent Space Distance Metrics #
#################################
class Cosine(nn.Module):
    """Cosine similarity between matching rows of two batches.

    Same computation as ``nn.CosineSimilarity()`` with its default settings:
    similarity is taken along ``dim=1`` with ``eps=1e-8`` guarding the norm.
    """

    def forward(self, x1, x2):
        similarity = nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-8)
        return similarity
class SquaredCosine(nn.Module):
    """Square of the cosine similarity between matching rows of two batches.

    Similarity is computed along ``dim=1`` with ``eps=1e-8`` (the
    ``nn.CosineSimilarity`` defaults) and then squared, so opposite
    directions score the same as aligned ones.
    """

    def forward(self, x1, x2):
        cos = nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-8)
        return cos.pow(2)
class Euclidean(nn.Module):
    """Euclidean (L2) distance between matching rows of two batches.

    ``x1[i]`` is compared only with ``x2[i]`` (distance taken over the last
    dimension), giving one value per pair, the same pairing ``Cosine`` uses.
    Inputs must have broadcast-compatible shapes.
    """

    def forward(self, x1, x2):
        # torch.cdist(x1, x2) would compare every row of x1 with every row of
        # x2 and return a (B, B) matrix; only the matched pairs are wanted.
        return torch.linalg.vector_norm(x1 - x2, ord=2, dim=-1)
class SquaredEuclidean(nn.Module):
    """Squared Euclidean distance between matching rows of two batches.

    For each pair ``i`` returns ``sum((x1[i] - x2[i]) ** 2)`` over the last
    dimension, giving one value per pair. Inputs must have
    broadcast-compatible shapes.
    """

    def forward(self, x1, x2):
        # torch.cdist(x1, x2) ** 2 would yield the all-pairs (B, B) matrix.
        # Summing squared differences directly also skips a sqrt/square round
        # trip and stays differentiable when a pair is identical.
        return (x1 - x2).pow(2).sum(dim=-1)
# Latent-space scoring modules, keyed by the name passed as
# ``latent_distance``; values are classes and are instantiated on use.
DISTANCE_METRICS = {
    "Cosine": Cosine,
    "SquaredCosine": SquaredCosine,
    "Euclidean": Euclidean,
    "SquaredEuclidean": SquaredEuclidean,
}

# Activation classes keyed by the name passed as ``latent_activation``.
# Stored uninstantiated so each projector can build its own instance.
ACTIVATIONS = {
    "ReLU": nn.ReLU,
    "GELU": nn.GELU,
    "ELU": nn.ELU,
    "Sigmoid": nn.Sigmoid,
}
class ConPLex_DTI(nn.Module, PyTorchModelHubMixin):
    """Drug-target interaction model scoring pairs in a shared latent space.

    Drug and target feature vectors are each mapped by one linear layer
    (Xavier-normal initialised weights) followed by a configurable activation
    into a ``latent_dimension``-wide space. A pair is then scored either with
    a latent distance metric (``classify=True``) or with the inner product of
    the two projections (``classify=False``).

    Args:
        drug_shape: Size of the last dimension of the drug input features.
        target_shape: Size of the last dimension of the target input features.
        latent_dimension: Width of the shared latent space.
        latent_activation: Key into ``ACTIVATIONS`` selecting the activation
            applied after each projection.
        latent_distance: Key into ``DISTANCE_METRICS`` selecting the metric
            used by :meth:`classify`; ignored when ``classify`` is False.
        classify: If True, :meth:`forward` dispatches to :meth:`classify`,
            otherwise to :meth:`regress`.

    Raises:
        KeyError: If ``latent_activation`` is unknown, or ``latent_distance``
            is unknown while ``classify`` is True.
    """

    def __init__(
        self,
        drug_shape=2048,
        target_shape=1024,
        latent_dimension=1024,
        latent_activation="ReLU",
        latent_distance="Cosine",
        classify=True,
    ):
        super().__init__()
        self.drug_shape = drug_shape
        self.target_shape = target_shape
        self.latent_dimension = latent_dimension
        self.do_classify = classify

        # Holds the activation *class*; each projector gets its own instance.
        self.latent_activation = self._lookup(
            ACTIVATIONS, latent_activation, "latent_activation"
        )

        # Keep the Sequential layout (Linear at index 0, activation at 1):
        # saved state_dict keys such as ``drug_projector.0.weight`` rely on it.
        self.drug_projector = nn.Sequential(
            nn.Linear(self.drug_shape, latent_dimension), self.latent_activation()
        )
        nn.init.xavier_normal_(self.drug_projector[0].weight)

        self.target_projector = nn.Sequential(
            nn.Linear(self.target_shape, latent_dimension), self.latent_activation()
        )
        nn.init.xavier_normal_(self.target_projector[0].weight)

        # The metric is only needed (and only validated) in classify mode.
        if self.do_classify:
            self.distance_metric = latent_distance
            self.activator = self._lookup(
                DISTANCE_METRICS, self.distance_metric, "latent_distance"
            )()

    @staticmethod
    def _lookup(registry, name, param):
        """Return ``registry[name]``, raising a KeyError that lists valid names."""
        try:
            return registry[name]
        except KeyError:
            raise KeyError(
                f"Unknown {param} {name!r}; expected one of {sorted(registry)}"
            ) from None

    def forward(self, drug, target):
        """Score drug/target pairs with :meth:`classify` or :meth:`regress`."""
        if self.do_classify:
            return self.classify(drug, target)
        else:
            return self.regress(drug, target)

    def regress(self, drug, target):
        """Score pairs by the inner product of their latent projections.

        ``view(-1, ...)`` flattens all leading dimensions, so the result holds
        one score per (drug, target) row pair; ``squeeze`` reduces it to a
        0-d tensor when there is only one pair.
        """
        drug_projection = self.drug_projector(drug)
        target_projection = self.target_projector(target)
        # (N, 1, D) @ (N, D, 1) -> (N, 1, 1): one dot product per row pair.
        inner_prod = torch.bmm(
            drug_projection.view(-1, 1, self.latent_dimension),
            target_projection.view(-1, self.latent_dimension, 1),
        )
        return inner_prod.squeeze()

    def classify(self, drug, target):
        """Score pairs with the configured latent distance metric.

        Only usable when the model was built with ``classify=True``;
        otherwise ``self.activator`` is never created. Singleton dimensions
        of the metric's output are squeezed away.
        """
        drug_projection = self.drug_projector(drug)
        target_projection = self.target_projector(target)
        distance = self.activator(drug_projection, target_projection)
        return distance.squeeze()
if __name__ == "__main__":
    # One-off export script: load a local ConPLex checkpoint into the default
    # model configuration, save it in Hub format, upload it, then load the
    # published copy back from the Hub.
    model_path = "./models/conplex_v1_bindingdb.pt"
    model = ConPLex_DTI()
    # NOTE(review): torch.load unpickles the file; only load checkpoints you
    # trust (consider weights_only=True on PyTorch versions that support it).
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.save_pretrained("ConPLex_V1_BindingDB")
    model.push_to_hub("ConPLex_V1_BindingDB")
    # Presumably a sanity check that the uploaded weights can be loaded.
    model = ConPLex_DTI.from_pretrained("samsl/ConPLex_V1_BindingDB")