import torch
import torch.nn as nn
import torch.nn.functional as F
import esm
from transformers import PreTrainedModel

from .configuration_TransHLA_I import TransHLA_I_Config

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


class TransHLA_I(nn.Module):
    def __init__(self, config):
        super(TransHLA_I, self).__init__()
        max_len = config.max_len
        n_layers = config.n_layers
        n_head = config.n_head
        d_model = config.d_model
        d_ff = config.d_ff
        cnn_padding_index = config.cnn_padding_index
        cnn_num_channel = config.cnn_num_channel
        region_embedding_size = config.region_embedding_size
        cnn_kernel_size = config.cnn_kernel_size
        cnn_padding_size = config.cnn_padding_size
        cnn_stride = config.cnn_stride
        pooling_size = config.pooling_size

        # Frozen ESM-2 (650M) backbone used as the feature extractor.
        self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()

        # Region-embedding convolutions: one over the per-residue embeddings,
        # one over the predicted contact map.
        self.region_cnn1 = nn.Conv1d(d_model, cnn_num_channel, region_embedding_size)
        self.region_cnn2 = nn.Conv1d(max_len, cnn_num_channel, region_embedding_size)
        self.padding1 = nn.ConstantPad1d((1, 1), 0)
        self.padding2 = nn.ConstantPad1d((0, 1), 0)
        self.relu = nn.ReLU()
        self.cnn1 = nn.Conv1d(cnn_num_channel, cnn_num_channel,
                              kernel_size=cnn_kernel_size,
                              padding=cnn_padding_size,
                              stride=cnn_stride)
        self.cnn2 = nn.Conv1d(cnn_num_channel, cnn_num_channel,
                              kernel_size=cnn_kernel_size,
                              padding=cnn_padding_size,
                              stride=cnn_stride)
        self.maxpooling = nn.MaxPool1d(kernel_size=pooling_size)

        # Transformer encoder applied on top of the ESM representations.
        self.transformer_layers = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_head, dim_feedforward=d_ff, dropout=0.2)
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_layers, num_layers=n_layers)

        self.bn1 = nn.BatchNorm1d(d_model)
        self.bn2 = nn.BatchNorm1d(cnn_num_channel)
        self.bn3 = nn.BatchNorm1d(cnn_num_channel)

        # Task head: fuses the transformer, sequence-CNN and structure-CNN features.
        self.fc_task = nn.Sequential(
            nn.Linear(d_model + 2 * cnn_num_channel, d_model // 4),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(d_model // 4, 64),
        )
        self.classifier = nn.Linear(64, 2)

    def cnn_block1(self, x):
        return self.cnn1(self.relu(x))

    def cnn_block2(self, x):
        # DPCNN-style block: pad, downsample by max pooling, then add a
        # two-layer convolutional residual branch.
        x = self.padding2(x)
        px = self.maxpooling(x)
        x = self.relu(px)
        x = self.cnn1(x)
        x = self.relu(x)
        x = self.cnn1(x)
        x = px + x
        return x

    def structure_block1(self, x):
        return self.cnn2(self.relu(x))

    def structure_block2(self, x):
        # Same DPCNN-style block, applied to the contact-map branch.
        x = self.padding2(x)
        px = self.maxpooling(x)
        x = self.relu(px)
        x = self.cnn2(x)
        x = self.relu(x)
        x = self.cnn2(x)
        x = px + x
        return x

    def forward(self, x_in):
        # The ESM-2 backbone is kept frozen: representations and contact maps
        # are extracted without tracking gradients.
        with torch.no_grad():
            results = self.esm(x_in, repr_layers=[33], return_contacts=True)
        emb = results["representations"][33]
        structure_emb = results["contacts"]

        # Transformer branch: encode the embeddings and keep the first-token
        # representation.
        output = self.transformer_encoder(emb)
        representation = output[:, 0, :]
        representation = self.bn1(representation)
        # Sequence-CNN branch over the per-residue ESM embeddings.
        cnn_emb = self.region_cnn1(emb.transpose(1, 2))
        cnn_emb = self.padding1(cnn_emb)
        conv = cnn_emb + self.cnn_block1(self.cnn_block1(cnn_emb))
        while conv.size(-1) >= 2:
            conv = self.cnn_block2(conv)
        cnn_out = torch.squeeze(conv, dim=-1)
        cnn_out = self.bn2(cnn_out)

        # Structure-CNN branch over the predicted contact map.
        structure_emb = self.region_cnn2(structure_emb.transpose(1, 2))
        structure_emb = self.padding1(structure_emb)
        structure_conv = structure_emb + \
            self.structure_block1(self.structure_block1(structure_emb))
        while structure_conv.size(-1) >= 2:
            structure_conv = self.structure_block2(structure_conv)
        structure_cnn_out = torch.squeeze(structure_conv, dim=-1)
        structure_cnn_out = self.bn3(structure_cnn_out)

        # Fuse the three branches and classify.
        representation = torch.cat(
            (representation, cnn_out, structure_cnn_out), dim=1)
        reduction_feature = self.fc_task(representation)
        reduction_feature = reduction_feature.view(reduction_feature.size(0), -1)
        logits_clsf = self.classifier(reduction_feature)
        logits_clsf = F.softmax(logits_clsf, dim=1)
        return logits_clsf, reduction_feature


class TransHLA_I_Model(PreTrainedModel):
    config_class = TransHLA_I_Config

    def __init__(self, config):
        super().__init__(config)
        self.model = TransHLA_I(config)

    def forward(self, tensor):
        return self.model(tensor)
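

# --- Usage sketch (illustrative, not part of the model definition) ---------
# A minimal, hedged example of calling the wrapper on one peptide. Everything
# below is an assumption for demonstration: the example peptide, the padding
# target, and instantiating from a fresh TransHLA_I_Config (in practice the
# released checkpoint would normally be loaded with `from_pretrained`).
# Tokens are padded so that the ESM contact map matches `config.max_len`,
# which the structure branch (`region_cnn2`) expects as its channel size.
if __name__ == "__main__":
    config = TransHLA_I_Config()            # assumed to carry usable defaults
    model = TransHLA_I_Model(config).eval()

    # Tokenise a hypothetical peptide with the backbone's ESM alphabet.
    batch_converter = model.model.alphabet.get_batch_converter()
    _, _, tokens = batch_converter([("example", "SIINFEKL")])

    # Pad with the alphabet's padding index so that, after ESM drops the
    # BOS/EOS positions, the contact map is max_len x max_len (assumed
    # convention; follow the model card for the exact target length).
    target_len = config.max_len + 2
    if tokens.size(1) < target_len:
        pad = torch.full((tokens.size(0), target_len - tokens.size(1)),
                         model.model.alphabet.padding_idx, dtype=tokens.dtype)
        tokens = torch.cat([tokens, pad], dim=1)

    with torch.no_grad():
        probs, features = model(tokens)
    print(probs)            # (1, 2) softmax class probabilities
    print(features.shape)   # (1, 64) reduced feature vector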