|
import math
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional
|
|
|
# Path one level above the folder that contains this module: split that
# folder's path into (parent, folder_name), keep the parent as `dirname` and
# discard the folder name into `_`. (os.path.split(p)[0] is exactly what
# os.path.dirname(p) returns.)
dirname, _ = os.path.split(os.path.split(__file__)[0])
|
|
|
|
|
@dataclass
class GeneEmbeddModelConfig:
    """Dataclass of default model settings for the GeneEmbedd model.

    Every field is a plain default. The class defines no ``__post_init__``, so
    nothing is validated or recomputed when an instance is constructed.

    NOTE(review): ``tokens_len`` is evaluated once, while the class body runs,
    from the class-level defaults ``max_length=40`` and ``window=2``. That gives
    ``math.ceil(40 / 2) == 20``. Overriding ``max_length`` or ``window`` in the
    constructor does NOT change ``tokens_len``. Pass ``tokens_len`` explicitly,
    or set it afterwards, whenever either of those fields changes.
    """

    model_input: str = ""

    num_embed_hidden: int = 100
    ff_input_dim: int = 0
    # List-valued defaults go through default_factory so each instance gets
    # its own list (dataclasses reject a bare mutable list default).
    ff_hidden_dim: List = field(default_factory=lambda: [300])
    feed_forward1_hidden: int = 256
    num_attention_project: int = 64
    num_encoder_layers: int = 1
    dropout: float = 0.2
    n: int = 121
    relative_attns: List = field(default_factory=lambda: [29, 4, 6, 8, 10, 11])
    num_attention_heads: int = 5

    window: int = 2

    max_length: int = 40
    # Computed from the two class attributes above at class-creation time.
    # It does not follow constructor overrides; see the class docstring.
    tokens_len: int = math.ceil(max_length / window)
    second_input_token_len: int = 0
    vocab_size: int = 0
    second_input_vocab_size: int = 0
    tokenizer: str = "overlap"

    clf_target: str = 'sub_class_hico'
    num_classes: int = 0
    class_mappings: List = field(default_factory=list)
    class_weights: List = field(default_factory=list)

    temperatures: List = field(default_factory=lambda: [0, 10])

    # Defaults to None, so it is annotated Optional. A bare ``Dict``
    # annotation with a None default is rejected by static type checkers.
    tokens_mapping_dict: Optional[Dict] = None
    false_input_perc: float = 0.0
|
|
|
|
|
@dataclass
class GeneEmbeddTrainConfig:
    """Dataclass of default training settings for the GeneEmbedd model.

    Every field is a plain default. The class defines no ``__post_init__``, so
    nothing is validated when an instance is constructed.

    NOTE(review): the three path defaults are absolute paths under
    ``/media/ftp_share/...``. They only resolve on a host where that mount
    exists, and presumably are overridden elsewhere; confirm with the caller.
    """

    dataset_path_train: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv'
    precursor_file_path: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/HBDxBase.csv'
    # The doubled "//" before "data" is kept byte-for-byte. POSIX pathname
    # resolution treats consecutive slashes as one, but string comparisons
    # against this value will not.
    mapping_dict_path: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json'
    device: str = "cuda"
    l2_weight_decay: float = 0.05
    batch_size: int = 512

    batch_per_epoch: int = 0

    label_smoothing_sim: float = 0.2
    label_smoothing_clf: float = 0.0

    learning_rate: float = 1e-3
    lr_warmup_start: float = 0.1
    lr_warmup_end: float = 1

    warmup_epoch: int = 10
    final_epoch: int = 20

    top_k: int = 10
    cross_val: bool = False
    # Defaults to None, so it is annotated Optional. A bare ``str``
    # annotation with a None default is rejected by static type checkers.
    labels_mapping_path: Optional[str] = None
    filter_seq_length: bool = False

    num_augment_exp: int = 20
    shuffle_exp: bool = False

    max_epochs: int = 3000
|
|
|
|
|
|