import torch
import torch.nn as nn
import torch.nn.functional as F
import esm

from transformers import PreTrainedModel

from .configuration_TransHLA_I import TransHLA_I_Config

# Let cuDNN autotune convolution kernels; input shapes are fixed by
# config.max_len, so this is a safe speed-up.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


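# This module defines TransHLA_I, an HLA class I epitope predictor (as its
# name suggests) built from three fused branches: a frozen ESM-2 backbone
# that provides residue embeddings and predicted contact maps, a transformer
# encoder over those embeddings, and two DPCNN-style convolutional pyramids,
# one per ESM output. A thin PreTrainedModel wrapper at the bottom exposes
# the network through the Hugging Face API.
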
class TransHLA_I(nn.Module):
    def __init__(self, config):
        super().__init__()

        max_len = config.max_len
        n_layers = config.n_layers
        n_head = config.n_head
        d_model = config.d_model
        d_ff = config.d_ff
        cnn_num_channel = config.cnn_num_channel
        region_embedding_size = config.region_embedding_size
        cnn_kernel_size = config.cnn_kernel_size
        cnn_padding_size = config.cnn_padding_size
        cnn_stride = config.cnn_stride
        pooling_size = config.pooling_size

        # Frozen ESM-2 (650M) backbone; supplies per-residue embeddings and
        # predicted residue-residue contact maps (see forward()).
        self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()

        # Region-embedding convolutions: one over the sequence embeddings
        # (in_channels = d_model) and one over the contact map
        # (in_channels = max_len).
        self.region_cnn1 = nn.Conv1d(
            d_model, cnn_num_channel, region_embedding_size)
        self.region_cnn2 = nn.Conv1d(
            max_len, cnn_num_channel, region_embedding_size)
        self.padding1 = nn.ConstantPad1d((1, 1), 0)
        self.padding2 = nn.ConstantPad1d((0, 1), 0)
        self.relu = nn.ReLU()
        self.cnn1 = nn.Conv1d(cnn_num_channel, cnn_num_channel, kernel_size=cnn_kernel_size,
                              padding=cnn_padding_size, stride=cnn_stride)
        self.cnn2 = nn.Conv1d(cnn_num_channel, cnn_num_channel, kernel_size=cnn_kernel_size,
                              padding=cnn_padding_size, stride=cnn_stride)
        self.maxpooling = nn.MaxPool1d(kernel_size=pooling_size)

        # Transformer branch over the ESM embeddings; batch_first=True because
        # the representations arrive as (batch, seq, dim).
        self.transformer_layers = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_head, dim_feedforward=d_ff, dropout=0.2,
            batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_layers, num_layers=n_layers)

        self.bn1 = nn.BatchNorm1d(d_model)
        self.bn2 = nn.BatchNorm1d(cnn_num_channel)
        self.bn3 = nn.BatchNorm1d(cnn_num_channel)

        # Fusion head: transformer CLS feature plus both CNN branch features.
        self.fc_task = nn.Sequential(
            nn.Linear(d_model + 2 * cnn_num_channel, d_model // 4),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(d_model // 4, 64),
        )
        self.classifier = nn.Linear(64, 2)

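    # The *_block methods below form DPCNN-style pyramid stages: *_block1 is
    # a single pre-activation convolution used in the residual stem, while
    # *_block2 right-pads, halves the length by max-pooling, and adds a
    # two-convolution residual branch on top of the pooled tensor.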
    def cnn_block1(self, x):
        return self.cnn1(self.relu(x))

    def cnn_block2(self, x):
        x = self.padding2(x)
        px = self.maxpooling(x)
        x = self.relu(px)
        x = self.cnn1(x)
        x = self.relu(x)
        x = self.cnn1(x)
        x = px + x
        return x

    def structure_block1(self, x):
        return self.cnn2(self.relu(x))

    def structure_block2(self, x):
        x = self.padding2(x)
        px = self.maxpooling(x)
        x = self.relu(px)
        x = self.cnn2(x)
        x = self.relu(x)
        x = self.cnn2(x)
        x = px + x
        return x

    def forward(self, x_in):
        # The ESM-2 backbone runs frozen: no gradients flow into it.
        with torch.no_grad():
            results = self.esm(x_in, repr_layers=[33], return_contacts=True)
            emb = results["representations"][33]  # (batch, seq, d_model)
            structure_emb = results["contacts"]   # predicted contact map

        # Branch 1: transformer encoder; keep the first (BOS/CLS) token as
        # the sequence-level representation.
        output = self.transformer_encoder(emb)
        representation = output[:, 0, :]
        representation = self.bn1(representation)

        # Branch 2: DPCNN over the sequence embeddings; pool until a single
        # position remains.
        cnn_emb = self.region_cnn1(emb.transpose(1, 2))
        cnn_emb = self.padding1(cnn_emb)
        conv = cnn_emb + self.cnn_block1(self.cnn_block1(cnn_emb))
        while conv.size(-1) >= 2:
            conv = self.cnn_block2(conv)
        cnn_out = torch.squeeze(conv, dim=-1)
        cnn_out = self.bn2(cnn_out)

        # Branch 3: the same DPCNN scheme over the predicted contact map.
        structure_emb = self.region_cnn2(structure_emb.transpose(1, 2))
        structure_emb = self.padding1(structure_emb)
        structure_conv = structure_emb + \
            self.structure_block1(self.structure_block1(structure_emb))
        while structure_conv.size(-1) >= 2:
            structure_conv = self.structure_block2(structure_conv)
        structure_cnn_out = torch.squeeze(structure_conv, dim=-1)
        structure_cnn_out = self.bn3(structure_cnn_out)

        # Fuse the three branches, reduce to 64 dimensions, and classify.
        representation = torch.cat(
            (representation, cnn_out, structure_cnn_out), dim=1)
        reduction_feature = self.fc_task(representation)
        reduction_feature = reduction_feature.view(
            reduction_feature.size(0), -1)
        probs = F.softmax(self.classifier(reduction_feature), dim=1)
        return probs, reduction_feature

    
class TransHLA_I_Model(PreTrainedModel):
    """Thin wrapper exposing TransHLA_I through the Hugging Face
    PreTrainedModel interface (config-driven construction, from_pretrained)."""

    config_class = TransHLA_I_Config

    def __init__(self, config):
        super().__init__(config)
        self.model = TransHLA_I(config)

    def forward(self, tensor):
        # Returns (class probabilities, 64-d fused representation).
        return self.model(tensor)
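

# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption-laden illustration, not shipped API).
# Assumptions: (1) TransHLA_I_Config() defaults are compatible with the layers
# above; (2) ESM drops the BOS/EOS tokens when computing contacts, so tokens
# are padded to config.max_len + 2; (3) trained weights would normally be
# loaded with TransHLA_I_Model.from_pretrained(<checkpoint>) rather than the
# randomly initialised head used here. Run as a module, e.g.
# `python -m <package>.modeling_TransHLA_I`, so the relative import resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = TransHLA_I_Config()
    model = TransHLA_I_Model(config).eval()

    # Tokenise a peptide with the ESM-2 alphabet bundled with the backbone.
    batch_converter = model.model.alphabet.get_batch_converter()
    _, _, tokens = batch_converter([("peptide", "SIINFEKL")])

    # Pad so the contact-map width matches region_cnn2's in_channels
    # (= config.max_len); the +2 accounts for BOS/EOS (assumption above).
    pad_to = config.max_len + 2
    if tokens.size(1) < pad_to:
        pad = tokens.new_full((tokens.size(0), pad_to - tokens.size(1)),
                              model.model.alphabet.padding_idx)
        tokens = torch.cat([tokens, pad], dim=1)

    with torch.no_grad():
        probs, features = model(tokens)
    print(probs)           # (1, 2) softmax class probabilities
    print(features.shape)  # (1, 64) fused representation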