| ''' | |
| Author: Qiguang Chen | |
| Date: 2023-01-11 10:39:26 | |
| LastEditors: Qiguang Chen | |
| LastEditTime: 2023-01-26 17:25:17 | |
| Description: Base encoder and bi encoder | |
| ''' | |
| from torch import nn | |
| from common.utils import InputData | |
class BaseEncoder(nn.Module):
    """Base class for all encoder modules.

    The keyword configuration passed to the constructor is stored unchanged on
    ``self.config``; other modules (e.g. ``BiEncoder``) read keys from it.
    This base class does not create any layers itself: a subclass must either
    assign ``self.encoder`` (used by the default :meth:`forward`) or override
    :meth:`forward`.
    """

    def __init__(self, **config):
        """Store ``config`` on the module.

        Args:
            **config: arbitrary encoder options, kept as a plain dict in
                ``self.config``.
        """
        super().__init__()
        self.config = config
        # NOTE: intentionally no ``raise NotImplementedError`` here — subclasses
        # call ``super().__init__(**config)``, so raising would break them all.

    def forward(self, inputs: InputData):
        """Encode ``inputs.input_ids`` with ``self.encoder`` and return the result.

        Args:
            inputs: batch object exposing an ``input_ids`` attribute.

        Returns:
            Whatever ``self.encoder`` returns for ``inputs.input_ids``.

        Raises:
            AttributeError: if the subclass never assigned ``self.encoder``
                (raised by ``nn.Module`` attribute lookup).
        """
        # Previously the encoder output was computed and discarded, so forward
        # always returned None.
        return self.encoder(inputs.input_ids)
class BiEncoder(nn.Module):
    """Encode intent and slot information with two independent encoders.

    Both encoders receive the same ``inputs``. The slot encoder's output is the
    object returned to the caller. The intent encoder's hidden state is attached
    to it via ``update_intent_hidden_state``.
    """

    def __init__(self, intent_encoder: BaseEncoder, slot_encoder: BaseEncoder, **config):
        """Register the two sub-encoders.

        Args:
            intent_encoder: encoder whose output supplies the intent hidden
                state; its ``config`` must contain
                ``"return_sentence_level_hidden"``.
            slot_encoder: encoder whose output object is returned by forward.
            **config: accepted for interface compatibility; not used here.
        """
        super().__init__()
        self.intent_encoder = intent_encoder
        self.slot_encoder = slot_encoder

    def forward(self, inputs: InputData):
        """Run both encoders and merge the intent state into the slot output.

        If the intent encoder is configured to return a sentence-level hidden
        state, that state is read with ``get_intent_hidden_state``; otherwise
        its token-level state is read with ``get_slot_hidden_state``.

        Returns:
            The slot encoder's output, after ``update_intent_hidden_state``.
        """
        slot_output = self.slot_encoder(inputs)
        intent_output = self.intent_encoder(inputs)
        sentence_level = self.intent_encoder.config["return_sentence_level_hidden"]
        intent_state = (
            intent_output.get_intent_hidden_state()
            if sentence_level
            else intent_output.get_slot_hidden_state()
        )
        slot_output.update_intent_hidden_state(intent_state)
        return slot_output