# File size: 2,098 Bytes
# 6799faa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import torch.nn as nn
import torch.nn.functional as F
import esm
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold, StratifiedShuffleSplit, StratifiedKFold
import collections
from torch.utils.data import DataLoader, TensorDataset
import os
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import f1_score
from sklearn.metrics import recall_score, precision_score
import random
from sklearn.metrics import auc
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
#import esm
from tqdm import tqdm
import time
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_recall_curve, average_precision_score, matthews_corrcoef, recall_score, f1_score, precision_score
# Turn on cuDNN and its benchmark mode (per the PyTorch docs, benchmark mode
# lets cuDNN time candidate convolution algorithms and reuse the fastest one).
# Chained assignment sets `enabled` first, then `benchmark`, both to True.
torch.backends.cudnn.enabled = torch.backends.cudnn.benchmark = True






from transformers import PretrainedConfig
from typing import List


class TransHLA_I_Config(PretrainedConfig):
    """Hugging Face configuration object holding TransHLA-I hyperparameters.

    Each named constructor argument is stored unchanged as an attribute of
    the same name. Any remaining keyword arguments are forwarded as-is to
    ``PretrainedConfig.__init__``, which runs after the attributes are set.

    NOTE(review): the meaning of each field is inferred from its name only
    (e.g. ``d_model`` as an embedding width, ``n_head`` as an attention-head
    count, the ``cnn_*`` fields as convolution settings). The model code
    that consumes this config is not visible here, so confirm against it.
    """

    # Model-type identifier declared on the config class.
    model_type = "TransHLA"

    def __init__(
        self,
        max_len=14,
        n_layers=6,
        n_head=8,
        d_model=1280,
        d_ff=64,
        cnn_padding_index=0,
        cnn_num_channel=256,
        region_embedding_size=3,
        cnn_kernel_size=3,
        cnn_padding_size=1,
        cnn_stride=1,
        pooling_size=2,
        **kwargs,
    ):
        # Gather the hyperparameters in declaration order, then attach each
        # one to the instance under its own name.
        hyperparameters = {
            "max_len": max_len,
            "n_layers": n_layers,
            "n_head": n_head,
            "d_model": d_model,
            "d_ff": d_ff,
            "cnn_padding_index": cnn_padding_index,
            "cnn_num_channel": cnn_num_channel,
            "region_embedding_size": region_embedding_size,
            "cnn_kernel_size": cnn_kernel_size,
            "cnn_padding_size": cnn_padding_size,
            "cnn_stride": cnn_stride,
            "pooling_size": pooling_size,
        }
        for attr_name, attr_value in hyperparameters.items():
            setattr(self, attr_name, attr_value)

        # Base-class initialisation comes last and receives only the
        # keyword arguments not consumed above.
        super().__init__(**kwargs)



# Build the configuration with its default hyperparameters and save it to
# the "TransHLA_I" directory via PretrainedConfig.save_pretrained.
# NOTE(review): this executes (and writes to disk) whenever the module is
# imported; consider moving it under an ``if __name__ == "__main__":`` guard.
transhla_i_config = TransHLA_I_Config()
# Alias kept for backward compatibility: the global used to be called
# ``resnet50d_config``, a name that refers to an unrelated architecture.
resnet50d_config = transhla_i_config
transhla_i_config.save_pretrained("TransHLA_I")