from torch import nn
from transformers import PreTrainedModel, AutoModel, AutoConfig

from .rna_torsionbert_config import RNATorsionBertConfig


class RNATorsionBERTModel(PreTrainedModel):
    """Regression head on top of a pretrained DNABERT backbone.

    A DNABERT encoder (selected by k-mer size ``config.k``) produces token
    embeddings, which a small MLP maps to ``config.num_classes`` values per
    token, squashed to (-1, 1) by a final Tanh.

    NOTE(review): the Tanh bound suggests the targets are sin/cos-encoded
    torsion angles — confirm against the training pipeline.
    """

    config_class = RNATorsionBertConfig

    def __init__(self, config):
        """Build the backbone and regression head.

        Args:
            config: an ``RNATorsionBertConfig`` providing ``k`` (k-mer size,
                one of 3-6), ``hidden_size`` (regressor width) and
                ``num_classes`` (outputs per token).
        """
        super().__init__(config)
        # Resolves self.model_name, self.revision and self.dnabert_config.
        self.init_model(config.k)
        # Fix: pin the weights to the same git revision as the config.
        # Previously only the config was pinned, so the downloaded weights
        # could silently diverge from the pinned configuration.
        self.dnabert = AutoModel.from_pretrained(
            self.model_name,
            config=self.dnabert_config,
            revision=self.revision,
            trust_remote_code=True,
        )
        # Per-token MLP: backbone hidden size -> config.hidden_size -> num_classes.
        self.regressor = nn.Sequential(
            nn.LayerNorm(self.dnabert_config.hidden_size),
            nn.Linear(self.dnabert_config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, config.num_classes),
        )
        # Bounds every output to (-1, 1).
        self.activation = nn.Tanh()

    def init_model(self, k: int) -> None:
        """Resolve the DNABERT checkpoint for k-mer size ``k``.

        Sets ``self.model_name``, ``self.revision`` (pinned commit hash for
        reproducibility) and ``self.dnabert_config``.

        Args:
            k: k-mer size; must be one of 3, 4, 5 or 6.

        Raises:
            KeyError: if ``k`` has no pinned revision (unsupported k-mer size).
        """
        model_name = f"zhihan1996/DNA_bert_{k}"
        # Commit hashes pinning each published DNA_bert_{k} checkpoint.
        revisions = {3: "ed28178", 4: "c8499f0", 5: "c296157", 6: "a79a8fd"}
        # KeyError here signals an unsupported k-mer size.
        self.revision = revisions[k]
        self.dnabert_config = AutoConfig.from_pretrained(
            model_name,
            revision=self.revision,
            trust_remote_code=True,
        )
        self.model_name = model_name

    def forward(self, tensor):
        """Run the backbone and regression head.

        Args:
            tensor: dict of tokenizer outputs (e.g. ``input_ids``,
                ``attention_mask``) unpacked as keyword args to DNABERT.

        Returns:
            dict with key ``"logits"``: a tensor of shape
            ``(..., seq_len, num_classes)`` with values in (-1, 1).
        """
        z = self.dnabert(**tensor).last_hidden_state
        output = self.regressor(z)
        output = self.activation(output)
        return {"logits": output}