import math
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union
from dataclasses import dataclass
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from transformers.models.esm import EsmPreTrainedModel, EsmModel
from transformers.models.bert import BertPreTrainedModel, BertModel
from .configuration_protst import ProtSTConfig


@dataclass
class EsmProteinRepresentationOutput(ModelOutput):
    """Output of `EsmForProteinRepresentation`.

    Attributes:
        protein_feature: Mean of `residue_feature` over non-special token positions.
        residue_feature: Last hidden state of the ESM encoder.
    """

    protein_feature: torch.FloatTensor = None
    residue_feature: torch.FloatTensor = None


@dataclass
class BertTextRepresentationOutput(ModelOutput):
    """Output of `BertForPubMed`.

    Attributes:
        text_feature: Mean of token features over non-special positions, passed through `text_mlp`.
        word_feature: Per-token BERT features passed through `word_mlp`.
    """

    text_feature: torch.FloatTensor = None
    word_feature: torch.FloatTensor = None


@dataclass
class ProtSTClassificationOutput(ModelOutput):
    """Output of `ProtSTForProteinPropertyPrediction`.

    Attributes:
        loss: Cross-entropy loss, only set when `labels` are provided.
        logits: Classifier scores, one row per input sequence.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None


def _special_token_mask(input_ids, special_token_ids):
    """Return a bool tensor that is True wherever `input_ids` equals one of `special_token_ids`.

    `None` ids are skipped, so an unset config id never matches any position.
    """
    is_special = torch.zeros_like(input_ids, dtype=torch.bool)
    for token_id in special_token_ids:
        if token_id is not None:
            is_special |= input_ids == token_id
    return is_special


def _mean_over_non_special(features, input_ids, special_token_ids):
    """Average `features` over dim 1, excluding positions holding a special token id.

    Args:
        features: Encoder output; assumed `(batch, seq_len, hidden)` aligned with `input_ids`.
        input_ids: Token ids used to decide which positions are averaged.
        special_token_ids: Ids whose positions are left out of the mean.

    Returns:
        The pooled tensor cast back to `features.dtype`. A row with no non-special
        positions pools to zeros; the 1e-6 term only guards against division by zero.
    """
    keep = (~_special_token_mask(input_ids, special_token_ids)).to(torch.int64).unsqueeze(-1)
    pooled = (features * keep).sum(1) / (keep.sum(1) + 1.0e-6)
    return pooled.to(features.dtype)


def _require_input_ids(input_ids, model_name):
    """Raise a clear error when `input_ids` is missing.

    The mean readout locates special tokens by id, so `inputs_embeds` alone cannot be
    pooled. Checking up front avoids running the encoder only to fail afterwards.
    """
    if input_ids is None:
        raise ValueError(
            f"{model_name} requires `input_ids` to locate special tokens for mean pooling; "
            "`inputs_embeds` alone is not supported."
        )


class ProtSTHead(nn.Module):
    """Two-layer MLP: Linear(hidden, hidden) -> ReLU -> Linear(hidden, out_dim)."""

    def __init__(self, config, out_dim=512):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, out_dim)

    def forward(self, x):
        x = self.dense(x)
        x = nn.functional.relu(x)
        x = self.out_proj(x)
        return x


class BertForPubMed(BertPreTrainedModel):
    """BERT text encoder with ProtST projection heads.

    Produces a sequence-level `text_feature` (mean over non-CLS/SEP/PAD tokens, then
    `text_mlp`) and token-level `word_feature` (`word_mlp` applied at every position).
    """

    def __init__(self, config):
        super().__init__(config)
        # Token ids excluded from the mean readout.
        self.pad_token_id = config.pad_token_id
        self.cls_token_id = config.cls_token_id
        self.sep_token_id = config.sep_token_id

        self.bert = BertModel(config, add_pooling_layer=False)
        self.text_mlp = ProtSTHead(config)
        self.word_mlp = ProtSTHead(config)

        # Initialize weights and apply final processing (transformers convention).
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], ModelOutput]:
        """Encode text and return pooled and per-token projected features.

        `input_ids` is required because pooling skips CLS/SEP/PAD positions by id.
        NOTE(review): attentions / hidden states requested via `output_*` are computed
        by the encoder but not returned.

        Returns:
            `BertTextRepresentationOutput` when `return_dict` is true, otherwise the tuple
            `(text_feature, word_feature)`.

        Raises:
            ValueError: if `input_ids` is None.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        _require_input_ids(input_ids, type(self).__name__)

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 0 is the last hidden state for both ModelOutput and tuple returns;
        # `.last_hidden_state` would fail when return_dict is False.
        word_feature = outputs[0]

        pooled_feature = _mean_over_non_special(
            word_feature, input_ids, (self.cls_token_id, self.sep_token_id, self.pad_token_id)
        )
        pooled_feature = self.text_mlp(pooled_feature)
        word_feature = self.word_mlp(word_feature)

        if not return_dict:
            return (pooled_feature, word_feature)

        return BertTextRepresentationOutput(text_feature=pooled_feature, word_feature=word_feature)


class EsmForProteinRepresentation(EsmPreTrainedModel):
    """ESM protein encoder returning per-residue features and their mean readout."""

    def __init__(self, config):
        super().__init__(config)
        # Token ids excluded from the mean readout.
        self.cls_token_id = config.cls_token_id
        self.pad_token_id = config.pad_token_id
        self.eos_token_id = config.eos_token_id

        self.esm = EsmModel(config, add_pooling_layer=False)

        # Initialize weights and apply final processing (transformers convention).
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, EsmProteinRepresentationOutput]:
        """Encode protein sequences and mean-pool residues into one vector per sequence.

        `input_ids` is required because pooling skips CLS/EOS/PAD positions by id.
        NOTE(review): attentions / hidden states requested via `output_*` are computed
        by the encoder but not returned.

        Returns:
            `EsmProteinRepresentationOutput` when `return_dict` is true, otherwise the tuple
            `(protein_feature, residue_feature)`.

        Raises:
            ValueError: if `input_ids` is None.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        _require_input_ids(input_ids, type(self).__name__)

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 0 works for both ModelOutput and tuple returns.
        residue_feature = outputs[0]

        # Mean readout over residues only.
        protein_feature = _mean_over_non_special(
            residue_feature, input_ids, (self.cls_token_id, self.eos_token_id, self.pad_token_id)
        )

        if not return_dict:
            return (protein_feature, residue_feature)

        return EsmProteinRepresentationOutput(
            protein_feature=protein_feature, residue_feature=residue_feature
        )


class ProtSTPreTrainedModel(PreTrainedModel):
    """Base class binding ProtST models to `ProtSTConfig`."""

    config_class = ProtSTConfig


class ProtSTForProteinPropertyPrediction(ProtSTPreTrainedModel):
    """ESM protein encoder followed by a `ProtSTHead` classifier over the pooled protein feature."""

    def __init__(self, config):
        super().__init__(config)  # PreTrainedModel stores `config` on self.

        self.protein_model = EsmForProteinRepresentation(config.protein_config)
        self.classifier = ProtSTHead(config.protein_config, out_dim=config.num_labels)

        # Initialize weights and apply final processing (transformers convention).
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ProtSTClassificationOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            One class index per protein, used to compute the cross-entropy loss. Indices should be in
            `[0, ..., config.num_labels - 1]`.

        Returns:
            `ProtSTClassificationOutput` when `return_dict` is true; otherwise `(loss, logits)` if
            `labels` is given, else `(logits,)`. `logits` has one row of `config.num_labels`
            scores per input sequence.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.protein_model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # protein_feature is element 0 of both the dataclass and the tuple form.
        protein_feature = outputs[0]
        logits = self.classifier(protein_feature)  # [bsz, hidden] -> [bsz, num_labels]

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Support model parallelism: labels may live on another device.
            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return ProtSTClassificationOutput(loss=loss, logits=logits)