File size: 2,098 Bytes
6799faa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import esm
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold, StratifiedShuffleSplit, StratifiedKFold
import collections
from torch.utils.data import DataLoader, TensorDataset
import os
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import f1_score
from sklearn.metrics import recall_score, precision_score
import random
from sklearn.metrics import auc
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
#import esm
from tqdm import tqdm
import time
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_recall_curve, average_precision_score, matthews_corrcoef, recall_score, f1_score, precision_score
# Enable cuDNN and let it benchmark/auto-select convolution algorithms.
# Benchmark mode pays off when input shapes stay fixed across iterations,
# at the cost of run-to-run nondeterminism in algorithm choice.
torch.backends.cudnn.enabled, torch.backends.cudnn.benchmark = True, True
from transformers import PretrainedConfig
from typing import List
class TransHLA_I_Config(PretrainedConfig):
    """Hugging Face configuration object for the TransHLA-I model.

    Every constructor argument except ``**kwargs`` is stored verbatim as an
    instance attribute of the same name. ``**kwargs`` is forwarded unchanged
    to ``PretrainedConfig.__init__``.

    NOTE(review): the parameter meanings below are inferred from their names
    only; confirm them against the model code that consumes this config.

    Args:
        max_len: presumably the maximum input sequence length.
        n_layers: presumably the number of transformer layers.
        n_head: presumably the number of attention heads.
        d_model: hidden width. Presumably matches the embedding width of the
            ESM backbone this file imports — confirm.
        d_ff: presumably the feed-forward layer width.
        cnn_padding_index: padding index for the CNN branch (assumed).
        cnn_num_channel: number of CNN channels (assumed).
        region_embedding_size: region-embedding size (assumed).
        cnn_kernel_size: CNN kernel size (assumed).
        cnn_padding_size: CNN padding size (assumed).
        cnn_stride: CNN stride (assumed).
        pooling_size: pooling window size (assumed).
        **kwargs: generic options passed through to ``PretrainedConfig``.
    """

    # Identifier that transformers records for this config type.
    model_type = "TransHLA"

    def __init__(
        self,
        max_len=14,
        n_layers=6,
        n_head=8,
        d_model=1280,
        d_ff=64,
        cnn_padding_index=0,
        cnn_num_channel=256,
        region_embedding_size=3,
        cnn_kernel_size=3,
        cnn_padding_size=1,
        cnn_stride=1,
        pooling_size=2,
        **kwargs,
    ):
        # Transformer-side settings (grouping inferred from the names).
        self.max_len, self.n_layers, self.n_head = max_len, n_layers, n_head
        self.d_model, self.d_ff = d_model, d_ff

        # CNN-side settings (grouping inferred from the ``cnn_`` prefix).
        self.cnn_padding_index, self.cnn_num_channel = (
            cnn_padding_index,
            cnn_num_channel,
        )
        self.region_embedding_size = region_embedding_size
        self.cnn_kernel_size, self.cnn_padding_size, self.cnn_stride = (
            cnn_kernel_size,
            cnn_padding_size,
            cnn_stride,
        )
        self.pooling_size = pooling_size

        # Base-class initialisation runs last, after the custom attributes
        # are in place.
        super().__init__(**kwargs)
# Build a config with all default hyper-parameters and serialise it into the
# local "TransHLA_I" directory. This runs at import time.
# NOTE(review): ``resnet50d_config`` looks like a name left over from the
# Hugging Face custom-model tutorial. It is kept because other code may
# import it.
resnet50d_config = TransHLA_I_Config()
resnet50d_config.save_pretrained(save_directory="TransHLA_I")
|