import math
import copy
import torch
from torch import nn
import torch.nn.functional as F
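
# MolTrans: an interaction network that embeds drug and protein token sequences with
# positional embeddings, passes each through its own transformer encoder, combines the
# encodings into a single-channel 2D interaction map, and scores it with a Conv2d + MLP
# decoder. A minimal usage sketch (an illustrative assumption, not part of the original
# training code) is provided at the bottom of the file.
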
class MolTrans(nn.Module):
"""
Interaction Network with 2D interaction map
"""
    def __init__(
            self,
            input_dim_drug: int,  # drug token vocabulary size, e.g. 23532
            input_dim_target: int,  # protein token vocabulary size, e.g. 16693
            max_drug_seq: int,
            max_protein_seq: int,
            emb_size: int = 384,
            dropout_rate: float = 0.1,
            # DenseNet
            scale_down_ratio: float = 0.25,
            growth_rate: int = 20,
            transition_rate: float = 0.5,
            num_dense_blocks: int = 4,
            kernal_dense_size: int = 3,
            # Encoder
            intermediate_size: int = 1536,
            num_attention_heads: int = 12,
            attention_probs_dropout_prob: float = 0.1,
            hidden_dropout_prob: float = 0.1,
            # flatten_dim: 78192,
            # batch_size
    ):
super().__init__()
self.max_d = max_drug_seq
self.max_p = max_protein_seq
self.emb_size = emb_size
self.dropout_rate = dropout_rate
# densenet
self.scale_down_ratio = scale_down_ratio
self.growth_rate = growth_rate
self.transition_rate = transition_rate
self.num_dense_blocks = num_dense_blocks
self.kernal_dense_size = kernal_dense_size
# self.batch_size = batch_size
self.input_dim_drug = input_dim_drug
self.input_dim_target = input_dim_target
        self.n_layer = 2  # number of encoder layers in each transformer branch
# encoder
self.hidden_size = emb_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.hidden_dropout_prob = hidden_dropout_prob
# self.flatten_dim = flatten_dim
        # token embeddings with learned positional embeddings for drug and protein sequences
self.demb = Embeddings(self.input_dim_drug, self.emb_size, self.max_d, self.dropout_rate)
self.pemb = Embeddings(self.input_dim_target, self.emb_size, self.max_p, self.dropout_rate)
self.d_encoder = EncoderMultipleLayers(self.n_layer, self.hidden_size, self.intermediate_size,
self.num_attention_heads, self.attention_probs_dropout_prob,
self.hidden_dropout_prob)
self.p_encoder = EncoderMultipleLayers(self.n_layer, self.hidden_size, self.intermediate_size,
self.num_attention_heads, self.attention_probs_dropout_prob,
self.hidden_dropout_prob)
        self.icnn = nn.Conv2d(1, 3, 3, padding=0)  # CNN over the single-channel interaction map
self.decoder = nn.Sequential(
# nn.Linear(self.flatten_dim, 512),
nn.LazyLinear(512),
nn.ReLU(True),
nn.BatchNorm1d(512),
nn.Linear(512, 64),
nn.ReLU(True),
nn.BatchNorm1d(64),
nn.Linear(64, 32),
nn.ReLU(True),
# # output layer
# nn.Linear(32, 1)
)
def forward(self, v_d, v_p):
d, d_mask = v_d
p, p_mask = v_p
ex_d_mask = d_mask.unsqueeze(1).unsqueeze(2)
ex_p_mask = p_mask.unsqueeze(1).unsqueeze(2)
ex_d_mask = (1.0 - ex_d_mask) * -10000.0
ex_p_mask = (1.0 - ex_p_mask) * -10000.0
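        # Additive attention masks: positions where the mask is 1 (real tokens) contribute 0
        # to the attention logits, while padded positions contribute -10000, pushing their
        # softmax weights to ~0.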
d_emb = self.demb(d) # batch_size x seq_length x embed_size
p_emb = self.pemb(p)
batch_size = d_emb.size(0)
        # The encoders return only the last encoder layer's hidden states (see EncoderMultipleLayers).
d_encoded_layers = self.d_encoder(d_emb.float(), ex_d_mask.float())
# print(d_encoded_layers.shape)
p_encoded_layers = self.p_encoder(p_emb.float(), ex_p_mask.float())
# print(p_encoded_layers.shape)
# repeat to have the same tensor size for aggregation
d_aug = torch.unsqueeze(d_encoded_layers, 2).repeat(1, 1, self.max_p, 1) # repeat along protein size
p_aug = torch.unsqueeze(p_encoded_layers, 1).repeat(1, self.max_d, 1, 1) # repeat along drug size
i = d_aug * p_aug # interaction
i_v = i.view(batch_size, -1, self.max_d, self.max_p)
# batch_size x embed size x max_drug_seq_len x max_protein_seq_len
i_v = torch.sum(i_v, dim=1)
i_v = torch.unsqueeze(i_v, 1)
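        # i_v is now a single-channel 2D interaction map of shape
        # (batch_size, 1, max_drug_seq, max_protein_seq).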
        i_v = F.dropout(i_v, p=self.dropout_rate, training=self.training)  # respect train/eval mode
i = self.icnn(i_v).view(batch_size, -1)
score = self.decoder(i)
return score
class LayerNorm(nn.Module):
def __init__(self, hidden_size, variance_epsilon=1e-12):
super(LayerNorm, self).__init__()
self.gamma = nn.Parameter(torch.ones(hidden_size))
self.beta = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = variance_epsilon
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.gamma * x + self.beta
class Embeddings(nn.Module):
"""Construct the embeddings from protein/target, position embeddings.
"""
def __init__(self, vocab_size, hidden_size, max_position_size, dropout_rate):
super(Embeddings, self).__init__()
self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
self.position_embeddings = nn.Embedding(max_position_size, hidden_size)
self.LayerNorm = LayerNorm(hidden_size)
self.dropout = nn.Dropout(dropout_rate)
def forward(self, input_ids):
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class SelfAttention(nn.Module):
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
super(SelfAttention, self).__init__()
if hidden_size % num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, num_attention_heads))
self.num_attention_heads = num_attention_heads
self.attention_head_size = int(hidden_size / num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(hidden_size, self.all_head_size)
self.key = nn.Linear(hidden_size, self.all_head_size)
self.value = nn.Linear(hidden_size, self.all_head_size)
self.dropout = nn.Dropout(attention_probs_dropout_prob)
    def transpose_for_scores(self, x):
        # (batch, seq_len, all_head_size) -> (batch, num_heads, seq_len, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
        attention_probs = F.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
class SelfOutput(nn.Module):
def __init__(self, hidden_size, hidden_dropout_prob):
super(SelfOutput, self).__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
self.LayerNorm = LayerNorm(hidden_size)
self.dropout = nn.Dropout(hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class Attention(nn.Module):
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
super(Attention, self).__init__()
self.self = SelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
self.output = SelfOutput(hidden_size, hidden_dropout_prob)
def forward(self, input_tensor, attention_mask):
self_output = self.self(input_tensor, attention_mask)
attention_output = self.output(self_output, input_tensor)
return attention_output
class Intermediate(nn.Module):
def __init__(self, hidden_size, intermediate_size):
super(Intermediate, self).__init__()
self.dense = nn.Linear(hidden_size, intermediate_size)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = F.relu(hidden_states)
return hidden_states
class Output(nn.Module):
def __init__(self, intermediate_size, hidden_size, hidden_dropout_prob):
super(Output, self).__init__()
self.dense = nn.Linear(intermediate_size, hidden_size)
self.LayerNorm = LayerNorm(hidden_size)
self.dropout = nn.Dropout(hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class Encoder(nn.Module):
def __init__(self, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob,
hidden_dropout_prob):
super(Encoder, self).__init__()
self.attention = Attention(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob)
self.intermediate = Intermediate(hidden_size, intermediate_size)
self.output = Output(intermediate_size, hidden_size, hidden_dropout_prob)
def forward(self, hidden_states, attention_mask):
attention_output = self.attention(hidden_states, attention_mask)
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class EncoderMultipleLayers(nn.Module):
def __init__(self, n_layer, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob,
hidden_dropout_prob):
super().__init__()
layer = Encoder(hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob,
hidden_dropout_prob)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layer)])
    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
        # Only the last layer's hidden states are returned; output_all_encoded_layers
        # is kept for interface compatibility but is not used.
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask)
return hidden_states
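

if __name__ == '__main__':
    # Minimal smoke-test sketch (an assumption for illustration, not part of the original
    # training pipeline). The vocabulary sizes match the values annotated in
    # MolTrans.__init__; the sequence lengths 50/545 are illustrative, chosen so that the
    # flattened interaction features equal the commented-out flatten_dim of 78192
    # (3 channels * 48 * 543 after the 3x3 convolution).
    max_drug_seq, max_protein_seq = 50, 545
    model = MolTrans(
        input_dim_drug=23532,
        input_dim_target=16693,
        max_drug_seq=max_drug_seq,
        max_protein_seq=max_protein_seq,
    )
    model.eval()
    batch_size = 2
    # Token indices plus binary padding masks (1 = real token, 0 = padding).
    d = torch.randint(0, 23532, (batch_size, max_drug_seq))
    d_mask = torch.ones(batch_size, max_drug_seq)
    p = torch.randint(0, 16693, (batch_size, max_protein_seq))
    p_mask = torch.ones(batch_size, max_protein_seq)
    with torch.no_grad():
        score = model((d, d_mask), (p, p_mask))
    print(score.shape)  # expected: torch.Size([2, 32])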