|
''' |
|
Author: Qiguang Chen |
|
Date: 2023-01-11 10:39:26 |
|
LastEditors: Qiguang Chen |
|
LastEditTime: 2023-01-26 17:25:17 |
|
Description: Base encoder and bi encoder |
|
|
|
''' |
|
from torch import nn |
|
|
|
from common.utils import InputData |
|
|
|
|
|
class BaseEncoder(nn.Module):
    """Base class for all encoder modules.

    The constructor only records the keyword configuration. It does not
    create ``self.encoder``; ``forward`` reads that attribute, so it has to
    be assigned elsewhere (presumably by subclasses) before the module is
    called.
    """

    def __init__(self, **config):
        """Store the keyword configuration on the module.

        Args:
            **config: arbitrary configuration options, kept verbatim as the
                ``self.config`` dict.
        """
        super().__init__()
        self.config = config
        # NOTE: the original code built a ``NotImplementedError`` here without
        # raising it, which had no effect. It is dropped rather than raised so
        # that ``super().__init__(**config)`` keeps working for code that
        # relies on this constructor to populate ``self.config``.

    def forward(self, inputs: "InputData"):
        """Run ``self.encoder`` on the input ids and return its output.

        Args:
            inputs: object exposing an ``input_ids`` attribute.

        Returns:
            Whatever ``self.encoder`` returns for ``inputs.input_ids``.

        Raises:
            AttributeError: if ``self.encoder`` has not been assigned.
        """
        # Return the result; previously it was computed and then discarded,
        # so every call evaluated to ``None``.
        return self.encoder(inputs.input_ids)
|
|
|
|
|
class BiEncoder(nn.Module):
    """Pair of encoders that encode intent and slot separately.

    Both encoders receive the same inputs; the intent encoder's hidden state
    is then written into the slot encoder's output object.
    """

    def __init__(self, intent_encoder: BaseEncoder, slot_encoder: BaseEncoder, **config):
        """Keep references to the two encoders.

        Args:
            intent_encoder: encoder whose output supplies the intent hidden
                state; its ``config`` is consulted in ``forward``.
            slot_encoder: encoder whose output object is returned.
            **config: accepted for interface compatibility; not used here.
        """
        super().__init__()
        self.intent_encoder = intent_encoder
        self.slot_encoder = slot_encoder

    def forward(self, inputs: InputData):
        """Encode ``inputs`` with both encoders and merge the results.

        Args:
            inputs: batch passed unchanged to both encoders.

        Returns:
            The slot encoder's output, after its intent hidden state has been
            replaced with the one produced by the intent encoder.
        """
        # The slot encoder runs first, then the intent encoder.
        slot_output = self.slot_encoder(inputs)
        intent_output = self.intent_encoder(inputs)

        # When the intent encoder is configured to return a sentence-level
        # hidden state, use its intent hidden state; otherwise use its slot
        # hidden state as the intent representation.
        use_sentence_level = self.intent_encoder.config["return_sentence_level_hidden"]
        if use_sentence_level:
            intent_state = intent_output.get_intent_hidden_state()
        else:
            intent_state = intent_output.get_slot_hidden_state()

        slot_output.update_intent_hidden_state(intent_state)
        return slot_output
|
|