|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import esm |
|
import numpy as np |
|
import pandas as pd |
|
from sklearn.model_selection import KFold, StratifiedShuffleSplit, StratifiedKFold |
|
import collections |
|
from torch.utils.data import DataLoader, TensorDataset |
|
import os |
|
from sklearn.metrics import roc_curve, roc_auc_score |
|
from sklearn.metrics import precision_recall_curve, average_precision_score |
|
from sklearn.metrics import matthews_corrcoef |
|
from sklearn.metrics import f1_score |
|
from sklearn.metrics import recall_score, precision_score |
|
import random |
|
from sklearn.metrics import auc |
|
from sklearn.decomposition import PCA |
|
import matplotlib.pyplot as plt |
|
|
|
from tqdm import tqdm |
|
import time |
|
import seaborn as sns |
|
from sklearn.metrics import confusion_matrix, precision_recall_curve, average_precision_score, matthews_corrcoef, recall_score, f1_score, precision_score |
|
# Let PyTorch use cuDNN kernels, and let cuDNN benchmark candidate convolution
# algorithms and cache the fastest one per input shape. The benchmarking helps
# when input shapes stay fixed, but it can make results non-deterministic
# across runs.
torch.backends.cudnn.enabled = torch.backends.cudnn.benchmark = True
|
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import PretrainedConfig |
|
from typing import List |
|
|
|
|
|
class TransHLA_II_Config(PretrainedConfig):
    """Hugging Face ``PretrainedConfig`` holding the TransHLA-II hyperparameters.

    Each named argument is stored unchanged as an attribute of the same name;
    no validation happens here. Any extra keyword arguments are forwarded
    as-is to ``PretrainedConfig.__init__``.

    The per-argument descriptions below are inferred from the names only. The
    model code that consumes them is not in this file, so confirm them against
    the model class.

    Args:
        max_len: Maximum input sequence length (presumably in tokens).
        n_layers: Number of stacked encoder layers (presumably transformer).
        n_head: Number of attention heads.
        d_model: Hidden/embedding width. NOTE(review): 1280 looks like the
            ESM-2 650M embedding size -- confirm.
        d_ff: Feed-forward inner width.
        cnn_padding_index: Index treated as padding by the CNN branch.
        cnn_num_channel: Number of CNN channels.
        region_embedding_size: Width of the CNN region embedding.
        cnn_kernel_size: CNN kernel size.
        cnn_padding_size: CNN padding amount.
        cnn_stride: CNN stride.
        pooling_size: Pooling window size.
        **kwargs: Passed through to ``PretrainedConfig``.
    """

    # Hugging Face model-type identifier. Note that it is "TransHLA", not
    # "TransHLA_II".
    model_type = "TransHLA"

    def __init__(
        self,
        max_len: int = 21,
        n_layers: int = 6,
        n_head: int = 8,
        d_model: int = 1280,
        d_ff: int = 64,
        cnn_padding_index: int = 0,
        cnn_num_channel: int = 256,
        region_embedding_size: int = 3,
        cnn_kernel_size: int = 3,
        cnn_padding_size: int = 1,
        cnn_stride: int = 1,
        pooling_size: int = 2,
        **kwargs,
    ):
        # Sequence / encoder hyperparameters (assigned left to right).
        self.max_len, self.n_layers, self.n_head = max_len, n_layers, n_head
        self.d_model, self.d_ff = d_model, d_ff

        # CNN-branch hyperparameters.
        self.cnn_padding_index, self.cnn_num_channel = (
            cnn_padding_index,
            cnn_num_channel,
        )
        self.region_embedding_size = region_embedding_size
        self.cnn_kernel_size, self.cnn_padding_size, self.cnn_stride = (
            cnn_kernel_size,
            cnn_padding_size,
            cnn_stride,
        )
        self.pooling_size = pooling_size

        # Everything else (standard HF config fields) is handled by the base class.
        super().__init__(**kwargs)
|
|
|
|
|
|
|
# Build the default TransHLA-II configuration and write it with
# ``save_pretrained`` into the "TransHLA_II" directory. The path is relative to
# the current working directory.
# NOTE(review): this runs every time the module is imported; consider moving it
# under an ``if __name__ == "__main__":`` guard.
transhla_ii_config = TransHLA_II_Config()

# Backward-compatible alias. The old name appears to be a leftover from the
# Hugging Face ResNet50d custom-config tutorial and is kept for any caller
# that still references it.
resnet50d_config = transhla_ii_config

transhla_ii_config.save_pretrained("TransHLA_II")
|
|
|
|
|
|
|
|