|
''' |
|
Author: Qiguang Chen |
|
Date: 2023-01-11 10:39:26 |
|
LastEditors: Qiguang Chen |
|
LastEditTime: 2023-01-31 18:22:36 |
|
Description: |
|
|
|
''' |
|
from torch import nn |
|
|
|
from common.utils import HiddenData, OutputData, InputData |
|
|
|
|
|
class BaseDecoder(nn.Module): |
|
"""Base class for all decoder module. |
|
|
|
Notice: t is often only necessary to change this module and its sub-modules |
|
""" |
|
def __init__(self, intent_classifier=None, slot_classifier=None, interaction=None): |
|
super().__init__() |
|
self.intent_classifier = intent_classifier |
|
self.slot_classifier = slot_classifier |
|
self.interaction = interaction |
|
|
|
def forward(self, hidden: HiddenData): |
|
"""forward |
|
|
|
Args: |
|
hidden (HiddenData): encoded data |
|
|
|
Returns: |
|
OutputData: prediction logits |
|
""" |
|
if self.interaction is not None: |
|
hidden = self.interaction(hidden) |
|
intent = None |
|
slot = None |
|
if self.intent_classifier is not None: |
|
intent = self.intent_classifier(hidden) |
|
if self.slot_classifier is not None: |
|
slot = self.slot_classifier(hidden) |
|
return OutputData(intent, slot) |
|
|
|
def decode(self, output: OutputData, target: InputData = None): |
|
"""decode output logits |
|
|
|
Args: |
|
output (OutputData): output logits data |
|
target (InputData, optional): input data with attention mask. Defaults to None. |
|
|
|
Returns: |
|
List: decoded sequence ids |
|
""" |
|
intent, slot = None, None |
|
if self.intent_classifier is not None: |
|
intent = self.intent_classifier.decode(output, target) |
|
if self.slot_classifier is not None: |
|
slot = self.slot_classifier.decode(output, target) |
|
return OutputData(intent, slot) |
|
|
|
def compute_loss(self, pred: OutputData, target: InputData, compute_intent_loss=True, compute_slot_loss=True): |
|
"""compute loss. |
|
Notice: can set intent and slot loss weight by adding 'weight' config item in corresponding classifier configuration. |
|
|
|
Args: |
|
pred (OutputData): output logits data |
|
target (InputData): input golden data |
|
compute_intent_loss (bool, optional): whether to compute intent loss. Defaults to True. |
|
compute_slot_loss (bool, optional): whether to compute intent loss. Defaults to True. |
|
|
|
Returns: |
|
Tensor: loss result |
|
""" |
|
loss = 0 |
|
intent_loss = None |
|
slot_loss = None |
|
if self.intent_classifier is not None: |
|
intent_loss = self.intent_classifier.compute_loss(pred, target) if compute_intent_loss else None |
|
intent_weight = self.intent_classifier.config.get("weight") |
|
intent_weight = intent_weight if intent_weight is not None else 1. |
|
loss += intent_loss * intent_weight |
|
if self.slot_classifier is not None: |
|
slot_loss = self.slot_classifier.compute_loss(pred, target) if compute_slot_loss else None |
|
slot_weight = self.slot_classifier.config.get("weight") |
|
slot_weight = slot_weight if slot_weight is not None else 1. |
|
loss += slot_loss * slot_weight |
|
return loss, intent_loss, slot_loss |
|
|
|
|
|
class StackPropagationDecoder(BaseDecoder): |
|
|
|
def forward(self, hidden: HiddenData): |
|
|
|
pred_intent = self.intent_classifier(hidden) |
|
|
|
|
|
hidden = self.interaction(pred_intent, hidden) |
|
pred_slot = self.slot_classifier(hidden) |
|
return OutputData(pred_intent, pred_slot) |
|
|
|
class DCANetDecoder(BaseDecoder): |
|
|
|
def forward(self, hidden: HiddenData): |
|
if self.interaction is not None: |
|
hidden = self.interaction(hidden, intent_emb=self.intent_classifier, slot_emb=self.slot_classifier) |
|
return OutputData(self.intent_classifier(hidden), self.slot_classifier(hidden)) |
|
|
|
|