# NOTE: web-page residue ("Spaces: Running on CPU Upgrade") preceded this module in the scraped copy; it is not part of the code.
import copy
import math

import torch
import torch.nn.functional as F
from torch import nn
class MolTrans(nn.Module):
    """Interaction network with a 2D drug-protein interaction map.

    Drug and protein token sequences are embedded (token + position), passed
    through separate transformer encoders, combined into a pairwise
    interaction map, filtered by a 3x3 convolution and reduced by an MLP to a
    32-dimensional feature vector (the final 32 -> 1 output layer is
    intentionally not part of this module).

    Args:
        input_dim_drug: Drug vocabulary size (the original signature hinted
            at 23532).
        input_dim_target: Protein vocabulary size (the original signature
            hinted at 16693).
        max_drug_seq: Padded drug sequence length.
        max_protein_seq: Padded protein sequence length.
        emb_size: Embedding / encoder hidden size.
        dropout_rate: Dropout used in the embeddings and on the interaction
            map.
        scale_down_ratio, growth_rate, transition_rate, num_dense_blocks,
        kernal_dense_size: Stored on the instance only; not used by any layer
            built here.
        intermediate_size: Width of the encoders' feed-forward layer.
        num_attention_heads: Number of attention heads per encoder layer.
        attention_probs_dropout_prob: Dropout on attention probabilities.
        hidden_dropout_prob: Dropout on encoder sub-layer outputs.
    """

    def __init__(
            self,
            input_dim_drug: int,
            input_dim_target: int,
            max_drug_seq: int,
            max_protein_seq: int,
            emb_size: int = 384,
            dropout_rate: float = 0.1,
            # DenseNet
            scale_down_ratio: float = 0.25,
            growth_rate: int = 20,
            transition_rate: float = 0.5,
            num_dense_blocks: int = 4,
            kernal_dense_size: int = 3,
            # Encoder
            intermediate_size: int = 1536,
            num_attention_heads: int = 12,
            attention_probs_dropout_prob: float = 0.1,
            hidden_dropout_prob: float = 0.1,
    ):
        super().__init__()
        self.max_d = max_drug_seq
        self.max_p = max_protein_seq
        self.emb_size = emb_size
        self.dropout_rate = dropout_rate

        # DenseNet hyper-parameters (kept as attributes; unused below).
        self.scale_down_ratio = scale_down_ratio
        self.growth_rate = growth_rate
        self.transition_rate = transition_rate
        self.num_dense_blocks = num_dense_blocks
        self.kernal_dense_size = kernal_dense_size

        self.input_dim_drug = input_dim_drug
        self.input_dim_target = input_dim_target
        self.n_layer = 2

        # Encoder hyper-parameters.
        self.hidden_size = emb_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.hidden_dropout_prob = hidden_dropout_prob

        # Token + positional embeddings for each modality.
        self.demb = Embeddings(self.input_dim_drug, self.emb_size, self.max_d, self.dropout_rate)
        self.pemb = Embeddings(self.input_dim_target, self.emb_size, self.max_p, self.dropout_rate)

        self.d_encoder = EncoderMultipleLayers(
            self.n_layer, self.hidden_size, self.intermediate_size,
            self.num_attention_heads, self.attention_probs_dropout_prob,
            self.hidden_dropout_prob)
        self.p_encoder = EncoderMultipleLayers(
            self.n_layer, self.hidden_size, self.intermediate_size,
            self.num_attention_heads, self.attention_probs_dropout_prob,
            self.hidden_dropout_prob)

        self.icnn = nn.Conv2d(1, 3, 3, padding=0)

        # LazyLinear infers the flattened conv size on the first forward pass.
        self.decoder = nn.Sequential(
            nn.LazyLinear(512),
            nn.ReLU(True),
            nn.BatchNorm1d(512),
            nn.Linear(512, 64),
            nn.ReLU(True),
            nn.BatchNorm1d(64),
            nn.Linear(64, 32),
            nn.ReLU(True),
        )

    def forward(self, v_d, v_p):
        """Return the decoder features for a batch of drug/protein pairs.

        Args:
            v_d: Pair ``(token_ids, mask)`` for the drug; the mask is 1 for
                real tokens and 0 for padding (assumed shape
                ``(batch, max_drug_seq)`` -- TODO confirm against caller).
            v_p: Pair ``(token_ids, mask)`` for the protein, same convention.

        Returns:
            The output of ``self.decoder`` (last layer is ``Linear(64, 32)``
            followed by ReLU).
        """
        d, d_mask = v_d
        p, p_mask = v_p

        # Turn the 0/1 masks into additive attention masks that broadcast
        # over heads and query positions: 0 where kept, -10000 where padded.
        ex_d_mask = (1.0 - d_mask.unsqueeze(1).unsqueeze(2)) * -10000.0
        ex_p_mask = (1.0 - p_mask.unsqueeze(1).unsqueeze(2)) * -10000.0

        d_emb = self.demb(d)  # batch_size x seq_length x embed_size
        p_emb = self.pemb(p)
        batch_size = d_emb.size(0)

        d_encoded = self.d_encoder(d_emb.float(), ex_d_mask.float())
        p_encoded = self.p_encoder(p_emb.float(), ex_p_mask.float())

        # Pairwise product of every drug position with every protein
        # position. `expand` yields the same shapes as the former `repeat`
        # calls but as views, so only the product itself is materialised.
        d_aug = d_encoded.unsqueeze(2).expand(-1, -1, self.max_p, -1)
        p_aug = p_encoded.unsqueeze(1).expand(-1, self.max_d, -1, -1)
        i = d_aug * p_aug  # batch x max_d x max_p x embed

        # NOTE(review): this reinterprets the (batch, max_d, max_p, embed)
        # memory as (batch, embed, max_d, max_p) without permuting axes, so
        # the sum over dim 1 is not a sum over the embedding axis of each
        # (drug, protein) pair. Kept as-is to stay numerically compatible
        # with existing checkpoints -- confirm whether a permute was intended.
        i_v = i.reshape(batch_size, -1, self.max_d, self.max_p)
        i_v = torch.sum(i_v, dim=1)
        i_v = torch.unsqueeze(i_v, 1)

        # Respect train/eval mode: F.dropout defaults to training=True and
        # would otherwise keep dropping activations during inference.
        i_v = F.dropout(i_v, p=self.dropout_rate, training=self.training)

        i = self.icnn(i_v).view(batch_size, -1)
        score = self.decoder(i)
        return score
class LayerNorm(nn.Module):
    """Normalize the last dimension to zero mean / unit variance, then scale and shift.

    Args:
        hidden_size: Size of the last dimension of the input.
        variance_epsilon: Added to the variance before the square root for
            numerical stability.
    """

    def __init__(self, hidden_size, variance_epsilon=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(hidden_size))   # learnable scale
        self.beta = nn.Parameter(torch.zeros(hidden_size))   # learnable shift
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        """Return ``gamma * (x - mean) / sqrt(var + eps) + beta`` over the last axis."""
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        # Biased (population) variance along the last axis.
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.gamma * normalized + self.beta
class Embeddings(nn.Module):
    """Construct the embeddings from protein/target tokens and position embeddings.

    Args:
        vocab_size: Number of rows in the token embedding table.
        hidden_size: Embedding dimension.
        max_position_size: Number of rows in the position embedding table.
        dropout_rate: Dropout probability applied to the normalized sum.
    """

    def __init__(self, vocab_size, hidden_size, max_position_size, dropout_rate):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input_ids):
        """Embed ``input_ids`` (indexed as ``(batch, seq_length)``) and add position embeddings."""
        seq_length = input_ids.size(1)
        # Positions 0..seq_length-1, shared by every row of the batch.
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        summed = self.word_embeddings(input_ids) + self.position_embeddings(position_ids)
        return self.dropout(self.LayerNorm(summed))
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        hidden_size: Input feature size; must be divisible by
            ``num_attention_heads``.
        num_attention_heads: Number of attention heads.
        attention_probs_dropout_prob: Dropout applied to the attention
            probabilities.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of
            ``num_attention_heads``.
    """

    def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))
        self.num_attention_heads = num_attention_heads
        # Divisibility was checked above, so floor division is exact.
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last axis into heads: (..., seq, all_head) -> (batch, heads, seq, head_size)."""
        split_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*split_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        """Attend over ``hidden_states``; ``attention_mask`` is added to the raw scores."""
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot product between "query" and "key" gives the raw scores;
        # the additive mask pushes masked positions towards zero probability.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = torch.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Merge the heads back: (batch, heads, seq, head) -> (batch, seq, all_head).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return context_layer.view(*merged_shape)
class SelfOutput(nn.Module):
    """Project the attention output, apply dropout, add the residual and normalize.

    Args:
        hidden_size: Feature size of both inputs and the output.
        hidden_dropout_prob: Dropout applied after the linear projection.
    """

    def __init__(self, hidden_size, hidden_dropout_prob):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection with the sub-layer input, then normalization.
        return self.LayerNorm(projected + input_tensor)
class Attention(nn.Module):
    """Self-attention followed by its output projection with residual connection.

    Args:
        hidden_size: Feature size of the input and output.
        num_attention_heads: Number of attention heads.
        attention_probs_dropout_prob: Dropout on the attention probabilities.
        hidden_dropout_prob: Dropout on the projected attention output.
    """

    def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
        super().__init__()
        self.self = SelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
        self.output = SelfOutput(hidden_size, hidden_dropout_prob)

    def forward(self, input_tensor, attention_mask):
        """Attend over ``input_tensor`` and combine the result with it via ``self.output``."""
        attended = self.self(input_tensor, attention_mask)
        # `input_tensor` is passed again as the residual branch.
        return self.output(attended, input_tensor)
class Intermediate(nn.Module):
    """Feed-forward expansion layer: linear projection followed by ReLU.

    Args:
        hidden_size: Input feature size.
        intermediate_size: Output feature size.
    """

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)

    def forward(self, hidden_states):
        """Return ``relu(dense(hidden_states))``."""
        return F.relu(self.dense(hidden_states))
class Output(nn.Module):
    """Project the feed-forward activations back to ``hidden_size`` with residual + norm.

    Args:
        intermediate_size: Feature size of the incoming activations.
        hidden_size: Feature size of the residual input and the output.
        hidden_dropout_prob: Dropout applied after the linear projection.
    """

    def __init__(self, intermediate_size, hidden_size, hidden_dropout_prob):
        super().__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection with the sub-layer input, then normalization.
        return self.LayerNorm(projected + input_tensor)
class Encoder(nn.Module):
    """One transformer encoder layer: attention block followed by a feed-forward block.

    Args:
        hidden_size: Feature size of the layer input and output.
        intermediate_size: Width of the feed-forward expansion.
        num_attention_heads: Number of attention heads.
        attention_probs_dropout_prob: Dropout on the attention probabilities.
        hidden_dropout_prob: Dropout on each sub-layer's projected output.
    """

    def __init__(self, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob,
                 hidden_dropout_prob):
        super().__init__()
        self.attention = Attention(hidden_size, num_attention_heads, attention_probs_dropout_prob,
                                   hidden_dropout_prob)
        self.intermediate = Intermediate(hidden_size, intermediate_size)
        self.output = Output(intermediate_size, hidden_size, hidden_dropout_prob)

    def forward(self, hidden_states, attention_mask):
        """Run attention, then the feed-forward block with the attention output as residual."""
        attention_output = self.attention(hidden_states, attention_mask)
        expanded = self.intermediate(attention_output)
        return self.output(expanded, attention_output)
class EncoderMultipleLayers(nn.Module):
    """Stack of ``n_layer`` encoder layers applied sequentially.

    Args:
        n_layer: Number of encoder layers in the stack.
        hidden_size: Feature size of the stack input and output.
        intermediate_size: Width of each layer's feed-forward expansion.
        num_attention_heads: Number of attention heads per layer.
        attention_probs_dropout_prob: Dropout on the attention probabilities.
        hidden_dropout_prob: Dropout on each sub-layer's projected output.
    """

    def __init__(self, n_layer, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob,
                 hidden_dropout_prob):
        super().__init__()
        prototype = Encoder(hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob,
                            hidden_dropout_prob)
        # Each layer is a deep copy of one prototype, so all layers start
        # from identical initial weights but own independent parameters.
        self.layer = nn.ModuleList([copy.deepcopy(prototype) for _ in range(n_layer)])

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
        """Return the hidden states produced by the last layer.

        ``output_all_encoded_layers`` is accepted for interface compatibility
        but ignored: only the final layer's output is ever returned.
        """
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask)
        return hidden_states