File size: 1,370 Bytes
2a44926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aab9a40
2a44926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aab9a40
2a44926
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from torch import nn
from transformers import PreTrainedModel, AutoModel, AutoConfig

from .rna_torsionbert_config import RNATorsionBertConfig


class RNATorsionBERTModel(PreTrainedModel):
    """Regression head on top of a pretrained DNABERT encoder.

    Wraps a k-mer DNABERT backbone (loaded from the Hugging Face Hub at a
    pinned revision) and maps its token embeddings through a small MLP to
    ``config.num_classes`` outputs per token, squashed to [-1, 1] by Tanh.
    """

    config_class = RNATorsionBertConfig

    # k-mer sizes that have a known-good, pinned DNABERT checkpoint revision.
    # Pinning guards against upstream Hub changes silently altering the model.
    _DNABERT_REVISIONS = {3: "ed28178", 4: "c8499f0", 5: "c296157", 6: "a79a8fd"}

    def __init__(self, config):
        """Build the DNABERT backbone and regression head from ``config``.

        Args:
            config: an ``RNATorsionBertConfig`` providing ``k`` (k-mer size),
                ``hidden_size`` (regressor width) and ``num_classes``.

        Raises:
            ValueError: if ``config.k`` has no pinned DNABERT checkpoint.
        """
        super().__init__(config)
        self.init_model(config.k)
        # trust_remote_code is required: DNABERT ships custom modeling code.
        self.dnabert = AutoModel.from_pretrained(
            self.model_name, config=self.dnabert_config, trust_remote_code=True
        )
        self.regressor = nn.Sequential(
            nn.LayerNorm(self.dnabert_config.hidden_size),
            nn.Linear(self.dnabert_config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, config.num_classes),
        )
        self.activation = nn.Tanh()

    def init_model(self, k: int) -> None:
        """Resolve and store the DNABERT checkpoint name/config for k-mer size ``k``.

        Sets ``self.model_name`` and ``self.dnabert_config`` as a side effect.

        Args:
            k: k-mer size; must be one of the supported sizes (3-6).

        Raises:
            ValueError: if ``k`` is not a supported k-mer size (previously this
                surfaced as an opaque ``KeyError``).
        """
        if k not in self._DNABERT_REVISIONS:
            raise ValueError(
                f"Unsupported k-mer size {k!r}; expected one of "
                f"{sorted(self._DNABERT_REVISIONS)}"
            )
        model_name = f"zhihan1996/DNA_bert_{k}"
        dnabert_config = AutoConfig.from_pretrained(
            model_name,
            revision=self._DNABERT_REVISIONS[k],
            trust_remote_code=True,
        )
        self.dnabert_config = dnabert_config
        self.model_name = model_name

    def forward(self, tensor):
        """Encode tokenized input and regress per-token torsion predictions.

        Args:
            tensor: dict of tokenizer outputs (e.g. ``input_ids``,
                ``attention_mask``) unpacked as keyword arguments into the
                DNABERT encoder.

        Returns:
            dict with key ``"logits"``: tanh-bounded tensor of shape
            (batch, seq_len, num_classes).
        """
        z = self.dnabert(**tensor).last_hidden_state
        output = self.regressor(z)
        output = self.activation(output)
        return {"logits": output}