# rna_torsionBERT / rna_torsionbert_model.py
# Uploaded by sayby ("Upload model", commit aab9a40, verified; 1.37 kB).
from torch import nn
from transformers import PreTrainedModel, AutoModel, AutoConfig
from .rna_torsionbert_config import RNATorsionBertConfig
class RNATorsionBERTModel(PreTrainedModel):
    """DNABERT backbone followed by a small regression head.

    The backbone is ``zhihan1996/DNA_bert_{k}`` with ``k`` taken from
    ``config.k``. Its ``last_hidden_state`` goes through
    LayerNorm -> Linear -> GELU -> Linear -> Tanh. The linear layers act on
    the last (hidden) dimension, so the leading dimensions of the backbone
    output are preserved and the last dimension becomes ``config.num_classes``.
    """

    config_class = RNATorsionBertConfig

    # Pinned commits of the DNABERT repositories, keyed by ``k``. Pinning
    # matters because the backbone is loaded with ``trust_remote_code=True``:
    # without an explicit revision, whatever is on the repo's default branch
    # at download time would be fetched (and, for remote code, executed).
    DNABERT_REVISIONS = {3: "ed28178", 4: "c8499f0", 5: "c296157", 6: "a79a8fd"}

    def __init__(self, config):
        super().__init__(config)
        self.init_model(config.k)
        # Load the backbone at the same pinned revision as its config, so the
        # config, weights and any remote code all come from one commit.
        self.dnabert = AutoModel.from_pretrained(
            self.model_name,
            config=self.dnabert_config,
            revision=self.dnabert_revision,
            trust_remote_code=True,
        )
        backbone_hidden = self.dnabert_config.hidden_size
        self.regressor = nn.Sequential(
            nn.LayerNorm(backbone_hidden),
            nn.Linear(backbone_hidden, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, config.num_classes),
        )
        self.activation = nn.Tanh()

    def init_model(self, k: int):
        """Resolve the DNABERT checkpoint for ``k`` and load its config.

        Sets ``self.model_name``, ``self.dnabert_revision`` and
        ``self.dnabert_config``. The config is fetched from the Hub at the
        pinned revision.

        Args:
            k: selects ``zhihan1996/DNA_bert_{k}`` (presumably the DNABERT
                k-mer size); must be a key of ``DNABERT_REVISIONS``.

        Raises:
            KeyError: if ``k`` has no pinned revision. This is raised before
                any download is attempted.
        """
        try:
            revision = self.DNABERT_REVISIONS[k]
        except KeyError:
            raise KeyError(
                f"Unsupported DNABERT k={k!r}; expected one of "
                f"{sorted(self.DNABERT_REVISIONS)}"
            ) from None
        model_name = f"zhihan1996/DNA_bert_{k}"
        self.dnabert_config = AutoConfig.from_pretrained(
            model_name,
            revision=revision,
            trust_remote_code=True,
        )
        self.model_name = model_name
        self.dnabert_revision = revision

    def forward(self, tensor):
        """Run the backbone and the regression head.

        Args:
            tensor: mapping of keyword arguments unpacked into the DNABERT
                backbone call (presumably tokenizer output such as
                ``input_ids`` / ``attention_mask`` — confirm against callers).

        Returns:
            dict with a single key ``"logits"``: the regression head output
            passed through tanh, so every value lies in [-1, 1], with last
            dimension ``config.num_classes``.
        """
        hidden_states = self.dnabert(**tensor).last_hidden_state
        return {"logits": self.activation(self.regressor(hidden_states))}