from typing import Dict

import torch
from transformers import AutoModel, AutoTokenizer

import comet
from comet.encoders.base import Encoder
from comet.encoders.bert import BERTEncoder


class robertaEncoder(BERTEncoder):
    def __init__(self, pretrained_model: str) -> None:
        # Skip BERTEncoder/Encoder initialization and call nn.Module.__init__
        # directly, since we load our own tokenizer and model below.
        super(Encoder, self).__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
        self.model = AutoModel.from_pretrained(
            pretrained_model, add_pooling_layer=False
        )
        self.model.encoder.output_hidden_states = True

    @classmethod
    def from_pretrained(cls, pretrained_model: str) -> Encoder:
        return robertaEncoder(pretrained_model)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
    ) -> Dict[str, torch.Tensor]:
        # With return_dict=False and add_pooling_layer=False, the model returns
        # (last_hidden_state, pooled_output=None, all_hidden_states).
        last_hidden_states, _, all_layers = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=False,
        )
        return {
            "sentemb": last_hidden_states[:, 0, :],  # embedding of the <s> token
            "wordemb": last_hidden_states,
            "all_layers": all_layers,
            "attention_mask": attention_mask,
        }


# Register the RoBERTa encoder in COMET's str2encoder mapping so models can
# refer to it by name.
comet.encoders.str2encoder['RoBERTa'] = robertaEncoder
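

# A minimal smoke test for the registered encoder (an illustrative sketch, not
# part of the original snippet): "roberta-base" and the sample sentence are
# arbitrary choices; any RoBERTa checkpoint should work the same way.
encoder = comet.encoders.str2encoder['RoBERTa'].from_pretrained("roberta-base")

batch = encoder.tokenizer(
    ["COMET supports custom encoders."], return_tensors="pt"
)
with torch.no_grad():
    out = encoder(batch["input_ids"], batch["attention_mask"])

print(out["sentemb"].shape)    # torch.Size([1, 768]) for roberta-base
print(len(out["all_layers"]))  # 13: embedding layer + 12 transformer layers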