|
''' |
|
Author: Qiguang Chen |
|
Date: 2023-01-11 10:39:26 |
|
LastEditors: Qiguang Chen |
|
LastEditTime: 2023-02-18 17:38:30 |
|
Description: pretrained encoder model |
|
|
|
''' |
|
from transformers import AutoModel, AutoConfig |
|
from common import utils |
|
|
|
from common.utils import InputData, HiddenData |
|
from model.encoder.base_encoder import BaseEncoder |
|
|
|
|
|
class PretrainedEncoder(BaseEncoder):
    """Encoder that wraps a pretrained Hugging Face transformer model."""

    def __init__(self, **config):
        """Initialise the pretrained encoder.

        Args:
            config (dict): encoder configuration, forwarded to ``BaseEncoder``.
                Keys read here:

                encoder_name (str): Hugging Face model name (or path) passed to
                    ``AutoModel.from_pretrained`` when not loading a checkpoint.
                _is_check_point_ (bool, optional): read via ``self.config``; when
                    truthy, the model is built with ``utils.instantiate`` from
                    ``config["pretrained_model"]`` instead of ``encoder_name``.
                pretrained_model (dict): required only when ``_is_check_point_``
                    is truthy.
        """
        super().__init__(**config)
        if self.config.get("_is_check_point_"):
            # NOTE(review): presumably rebuilds the model from a config stored in
            # the checkpoint; the class to build appears to be named under the
            # "_pretrained_model_target_" key -- confirm against utils.instantiate.
            self.encoder = utils.instantiate(
                config["pretrained_model"], target="_pretrained_model_target_"
            )
        else:
            self.encoder = AutoModel.from_pretrained(config["encoder_name"])

    def forward(self, inputs: InputData):
        """Run the pretrained model and wrap its output in ``HiddenData``.

        Args:
            inputs (InputData): batch whose ``get_inputs()`` result is unpacked as
                keyword arguments to the underlying model.

        Returns:
            HiddenData: built from the model's ``last_hidden_state``.
            - If ``return_with_input`` is set in the config, ``inputs`` is
              attached via ``add_input``.
            - If ``return_sentence_level_hidden`` is set, a sentence-level vector
              is stored via ``update_intent_hidden_state``. It is the model's
              ``pooler_output`` when that is available (not ``None``). Otherwise
              it is the last position when ``padding_side == "left"``, or the
              first position for any other ``padding_side``.
        """
        output = self.encoder(**inputs.get_inputs())
        hidden = HiddenData(None, output.last_hidden_state)

        if self.config.get("return_with_input"):
            hidden.add_input(inputs)

        if self.config.get("return_sentence_level_hidden"):
            # Hugging Face "*WithPooling" outputs always define `pooler_output`,
            # but it is None when the model has no pooling layer, so a plain
            # hasattr() check would pass None downstream. Test the value instead.
            pooled = getattr(output, "pooler_output", None)
            if pooled is not None:
                sentence_hidden = pooled
            elif self.config.get("padding_side") == "left":
                # With left padding the final position holds a real token.
                # Assumes last_hidden_state is (batch, seq_len, hidden) -- TODO confirm.
                sentence_hidden = output.last_hidden_state[:, -1]
            else:
                # With right padding the first position holds a real token.
                sentence_hidden = output.last_hidden_state[:, 0]
            hidden.update_intent_hidden_state(sentence_hidden)

        return hidden
|
|