# TransHLA_I / modeling_TransHLA_I.py
import torch
import torch.nn as nn
import esm
from transformers import PreTrainedModel

from .configuration_TransHLA_I import TransHLA_I_Config

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

class TransHLA_I(nn.Module):
    def __init__(self, config):
        super(TransHLA_I, self).__init__()
        max_len = config.max_len
        n_layers = config.n_layers
        n_head = config.n_head
        d_model = config.d_model
        d_ff = config.d_ff
        cnn_num_channel = config.cnn_num_channel
        region_embedding_size = config.region_embedding_size
        cnn_kernel_size = config.cnn_kernel_size
        cnn_padding_size = config.cnn_padding_size
        cnn_stride = config.cnn_stride
        pooling_size = config.pooling_size

        # Frozen ESM-2 backbone used to extract residue embeddings and contact maps.
        self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()

        # Region embeddings for the sequence branch (over ESM features) and the
        # structure branch (over the predicted contact map).
        self.region_cnn1 = nn.Conv1d(
            d_model, cnn_num_channel, region_embedding_size)
        self.region_cnn2 = nn.Conv1d(
            max_len, cnn_num_channel, region_embedding_size)
        self.padding1 = nn.ConstantPad1d((1, 1), 0)
        self.padding2 = nn.ConstantPad1d((0, 1), 0)
        self.relu = nn.ReLU()
        self.cnn1 = nn.Conv1d(cnn_num_channel, cnn_num_channel, kernel_size=cnn_kernel_size,
                              padding=cnn_padding_size, stride=cnn_stride)
        self.cnn2 = nn.Conv1d(cnn_num_channel, cnn_num_channel, kernel_size=cnn_kernel_size,
                              padding=cnn_padding_size, stride=cnn_stride)
        self.maxpooling = nn.MaxPool1d(kernel_size=pooling_size)

        # Transformer encoder applied on top of the ESM residue embeddings.
        self.transformer_layers = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_head, dim_feedforward=d_ff, dropout=0.2)
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_layers, num_layers=n_layers)

        self.bn1 = nn.BatchNorm1d(d_model)
        self.bn2 = nn.BatchNorm1d(cnn_num_channel)
        self.bn3 = nn.BatchNorm1d(cnn_num_channel)

        # Projection head over the fused representation
        # (transformer feature + sequence CNN + structure CNN).
        self.fc_task = nn.Sequential(
            nn.Linear(d_model + 2 * cnn_num_channel, d_model // 4),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(d_model // 4, 64),
        )
        self.classifier = nn.Linear(64, 2)
    def cnn_block1(self, x):
        return self.cnn1(self.relu(x))

    def cnn_block2(self, x):
        # Downsampling block for the sequence branch: pooling shortens the last
        # dimension, then two convolutions with a residual connection back to
        # the pooled input.
        x = self.padding2(x)
        px = self.maxpooling(x)
        x = self.relu(px)
        x = self.cnn1(x)
        x = self.relu(x)
        x = self.cnn1(x)
        x = px + x
        return x

    def structure_block1(self, x):
        return self.cnn2(self.relu(x))

    def structure_block2(self, x):
        # Same downsampling block, using the structure-branch convolution.
        x = self.padding2(x)
        px = self.maxpooling(x)
        x = self.relu(px)
        x = self.cnn2(x)
        x = self.relu(x)
        x = self.cnn2(x)
        x = px + x
        return x
    def forward(self, x_in):
        # The ESM-2 backbone is used as a frozen feature extractor: per-residue
        # embeddings from layer 33 plus the predicted contact map.
        with torch.no_grad():
            results = self.esm(x_in, repr_layers=[33], return_contacts=True)
            emb = results["representations"][33]
            structure_emb = results["contacts"]

        # Transformer branch: encode the embeddings and take the slice at
        # index 0 along dim 1 as the pooled representation.
        output = self.transformer_encoder(emb)
        representation = output[:, 0, :]
        representation = self.bn1(representation)

        # Sequence CNN branch over the ESM embeddings, pooled down to one column.
        cnn_emb = self.region_cnn1(emb.transpose(1, 2))
        cnn_emb = self.padding1(cnn_emb)
        conv = cnn_emb + self.cnn_block1(self.cnn_block1(cnn_emb))
        while conv.size(-1) >= 2:
            conv = self.cnn_block2(conv)
        cnn_out = torch.squeeze(conv, dim=-1)
        cnn_out = self.bn2(cnn_out)

        # Structure CNN branch over the contact map, pooled down the same way.
        structure_emb = self.region_cnn2(structure_emb.transpose(1, 2))
        structure_emb = self.padding1(structure_emb)
        structure_conv = structure_emb + \
            self.structure_block1(self.structure_block1(structure_emb))
        while structure_conv.size(-1) >= 2:
            structure_conv = self.structure_block2(structure_conv)
        structure_cnn_out = torch.squeeze(structure_conv, dim=-1)
        structure_cnn_out = self.bn3(structure_cnn_out)

        # Fuse the three branches, project to a 64-dim feature, and classify.
        representation = torch.cat(
            (representation, cnn_out, structure_cnn_out), dim=1)
        reduction_feature = self.fc_task(representation)
        reduction_feature = reduction_feature.view(
            reduction_feature.size(0), -1)
        logits_clsf = self.classifier(reduction_feature)
        logits_clsf = torch.nn.functional.softmax(logits_clsf, dim=1)
        return logits_clsf, reduction_feature
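
# --- Illustrative helper (not part of the original file) ---------------------
# A minimal sketch of one way to turn raw peptide strings into the token tensor
# that TransHLA_I.forward expects, using the batch converter that ships with the
# fair-esm alphabet. The function name `encode_peptides`, the right-padding with
# the alphabet's padding index, and the choice of `target_len` are illustrative
# assumptions, not the authors' preprocessing pipeline.
def encode_peptides(peptides, alphabet, target_len):
    """Tokenize peptide strings and right-pad/truncate them to target_len columns."""
    batch_converter = alphabet.get_batch_converter()
    _, _, tokens = batch_converter([(str(i), seq) for i, seq in enumerate(peptides)])
    if tokens.size(1) < target_len:
        pad = torch.full((tokens.size(0), target_len - tokens.size(1)),
                         alphabet.padding_idx, dtype=tokens.dtype)
        tokens = torch.cat([tokens, pad], dim=1)
    return tokens[:, :target_len]
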
class TransHLA_I_Model(PreTrainedModel):
    """Hugging Face wrapper so the model can be saved and loaded with from_pretrained."""
    config_class = TransHLA_I_Config

    def __init__(self, config):
        super().__init__(config)
        self.model = TransHLA_I(config)

    def forward(self, tensor):
        return self.model(tensor)
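
# --- Usage sketch (not part of the original file) -----------------------------
# Assumes TransHLA_I_Config() provides working defaults and that the contact-map
# branch expects a token width of config.max_len + 2 (config.max_len contact
# positions plus the BOS/EOS tokens ESM adds); both points are inferred from the
# layer shapes above, not from the authors' documented usage, and the example
# peptides are arbitrary.
if __name__ == "__main__":
    config = TransHLA_I_Config()
    model = TransHLA_I_Model(config)
    model.eval()

    peptides = ["SIINFEKL", "LLFGYPVYV"]  # illustrative 8-mer / 9-mer peptides
    tokens = encode_peptides(peptides, model.model.alphabet, config.max_len + 2)

    with torch.no_grad():
        probs, features = model(tokens)
    print(probs.shape, features.shape)  # expected: (2, 2) and (2, 64)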