'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-26 17:18:22
Description: Root Model Module

'''
from torch import nn

from common.utils import OutputData, InputData
from model.decoder.base_decoder import BaseDecoder
from model.encoder.base_encoder import BaseEncoder


class OpenSLUModel(nn.Module):
    """Root model that chains one encoder with one decoder.

    Every public method delegates to the wrapped ``encoder`` / ``decoder``;
    this class adds no computation of its own.
    """

    def __init__(self, encoder: BaseEncoder, decoder: BaseDecoder, **config):
        """Build the model from an already-constructed encoder and decoder.

        Args:
            encoder (BaseEncoder): encoder instance (created from the config by the caller)
            decoder (BaseDecoder): decoder instance (created from the config by the caller)
            config (dict): any remaining keyword arguments, stored unchanged on ``self.config``
        """
        super().__init__()
        # Assigned in encoder -> decoder order so nn.Module registers the
        # sub-modules in that order.
        self.encoder = encoder
        self.decoder = decoder
        self.config = config

    def forward(self, inp: InputData) -> OutputData:
        """Run the encoder on ``inp`` and pass its result to the decoder.

        Args:
            inp (InputData): input ids and other information

        Returns:
            OutputData: whatever the decoder returns (the prediction logits, per the type hint)
        """
        encoded = self.encoder(inp)
        prediction = self.decoder(encoded)
        return prediction

    def decode(self, output: OutputData, target: InputData = None):
        """Turn model output into decoded ids by delegating to ``decoder.decode``.

        Args:
            output (OutputData): pred logits data
            target (InputData, optional): golden data, forwarded to the decoder as-is. Defaults to None.

        Returns:
            the value returned by ``self.decoder.decode`` (decoded ids)
        """
        decoded = self.decoder.decode(output, target)
        return decoded

    def compute_loss(self, pred: OutputData, target: InputData, compute_intent_loss=True, compute_slot_loss=True):
        """Compute the loss by delegating to ``decoder.compute_loss``.

        Args:
            pred (OutputData): pred logits data
            target (InputData): golden data
            compute_intent_loss (bool, optional): whether to compute intent loss. Defaults to True.
            compute_slot_loss (bool, optional): whether to compute slot loss. Defaults to True.

        Returns:
            the loss value returned by the decoder
        """
        # Both switches are forwarded by keyword, unchanged.
        loss = self.decoder.compute_loss(
            pred,
            target,
            compute_intent_loss=compute_intent_loss,
            compute_slot_loss=compute_slot_loss,
        )
        return loss