'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-31 20:07:00
Description: Classifier (decoder) modules built on BaseClassifier:
    LinearClassifier, AutoregressiveLSTMClassifier and MLPClassifier.
'''
|
import random |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
from torch.nn import CrossEntropyLoss |
|
|
|
from model.decoder import decoder_utils |
|
|
|
from torchcrf import CRF |
|
|
|
from common.utils import HiddenData, OutputData, InputData, ClassifierOutputData, unpack_sequence, pack_sequence, \ |
|
instantiate |
|
|
|
|
|
class BaseClassifier(nn.Module):
    """Base class for all classifier modules.
    """
    def __init__(self, **config):
        super().__init__()
        self.config = config
        if config.get("loss_fn"):
            self.loss_fn = instantiate(config.get("loss_fn"))
        else:
            # Fall back to cross entropy; an integer "ignore_index" is expected in the config.
            self.loss_fn = CrossEntropyLoss(ignore_index=self.config.get("ignore_index"))

    def forward(self, *args, **kwargs):
        raise NotImplementedError("forward() must be implemented by BaseClassifier subclasses.")

    def decode(self, output: OutputData,
               target: InputData = None,
               return_list=True,
               return_sentence_level=None):
        """Decode output logits into label ids.

        Args:
            output (OutputData): output logits data
            target (InputData, optional): input data with attention mask. Defaults to None.
            return_list (bool, optional): if True, return a list; otherwise return a torch Tensor. Defaults to True.
            return_sentence_level (bool, optional): if True, decode sentence-level intent; otherwise decode token-level intent. Defaults to None.

        Returns:
            List or Tensor: decoded sequence ids
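
        Example:
            A minimal usage sketch (assumes ``clf`` is an instance of a concrete
            subclass, ``output`` is the OutputData returned by ``forward`` and
            ``inputs`` is the corresponding InputData with an attention mask)::

                pred_ids = clf.decode(output, target=inputs, return_list=True)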
        """
        if self.config.get("return_sentence_level") is not None and return_sentence_level is None:
            return_sentence_level = self.config.get("return_sentence_level")
        elif self.config.get("return_sentence_level") is None and return_sentence_level is None:
            return_sentence_level = False
        return decoder_utils.decode(output, target,
                                    return_list=return_list,
                                    return_sentence_level=return_sentence_level,
                                    pred_type=self.config.get("mode"),
                                    use_multi=self.config.get("use_multi"),
                                    multi_threshold=self.config.get("multi_threshold"))

    def compute_loss(self, pred: OutputData, target: InputData):
        """Compute the training loss.

        Args:
            pred (OutputData): output logits data
            target (InputData): golden (ground-truth) input data

        Returns:
            Tensor: loss result
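
        Example:
            A minimal usage sketch (assumes ``clf`` is an instance of a concrete
            subclass, ``pred`` is the OutputData returned by ``forward`` and
            ``inputs`` is the golden InputData)::

                loss = clf.compute_loss(pred, inputs)
                loss.backward()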
        """
        _CRF = None
        if self.config.get("use_crf"):
            _CRF = self.CRF
        return decoder_utils.compute_loss(pred, target, criterion_type=self.config["mode"],
                                          use_crf=_CRF is not None,
                                          ignore_index=self.config["ignore_index"],
                                          use_multi=self.config.get("use_multi"),
                                          loss_fn=self.loss_fn,
                                          CRF=_CRF)


class LinearClassifier(BaseClassifier):
    """
    Decoder structure based on a linear layer.
    """
    def __init__(self, **config):
        """Construction function for LinearClassifier.

        Args:
            config (dict):
                input_dim (int): hidden state dim.
                use_slot (bool): whether to classify slot labels.
                slot_label_num (int, optional): the number of slot labels. Enabled if use_slot is True.
                use_intent (bool): whether to classify intent labels.
                intent_label_num (int, optional): the number of intent labels. Enabled if use_intent is True.
                use_crf (bool): whether to use a CRF layer for slot filling.
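
        Example:
            A minimal configuration sketch for intent detection (key names follow
            the Args above plus the ``mode`` and ``ignore_index`` keys read by
            BaseClassifier; the numbers are placeholders)::

                classifier = LinearClassifier(
                    mode="intent",
                    input_dim=768,
                    use_intent=True,
                    intent_label_num=20,
                    use_slot=False,
                    ignore_index=-100,
                )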
        """
        super().__init__(**config)
        self.config = config
        if config.get("use_slot"):
            self.slot_classifier = nn.Linear(config["input_dim"], config["slot_label_num"])
            if self.config.get("use_crf"):
                self.CRF = CRF(num_tags=config["slot_label_num"], batch_first=True)
        if config.get("use_intent"):
            self.intent_classifier = nn.Linear(config["input_dim"], config["intent_label_num"])

    def forward(self, hidden: HiddenData):
        # Intent takes precedence: when both heads are enabled, this forward returns
        # only the intent logits.
        if self.config.get("use_intent"):
            return ClassifierOutputData(self.intent_classifier(hidden.get_intent_hidden_state()))
        if self.config.get("use_slot"):
            return ClassifierOutputData(self.slot_classifier(hidden.get_slot_hidden_state()))


class AutoregressiveLSTMClassifier(BaseClassifier):
    """
    Decoder structure based on unidirectional LSTM.
    """

    def __init__(self, **config):
        """Construction function for the autoregressive LSTM decoder.

        Args:
            config (dict):
                input_dim (int): input dimension of the decoder, i.e. the encoder hidden size.
                use_slot (bool): whether to classify slot labels.
                slot_label_num (int, optional): the number of slot labels. Enabled if use_slot is True.
                use_intent (bool): whether to classify intent labels.
                intent_label_num (int, optional): the number of intent labels. Enabled if use_intent is True.
                use_crf (bool): whether to use a CRF layer for slot filling.
                hidden_dim (int): hidden dimension of the autoregressive LSTM.
                embedding_dim (int, optional): if not None, the embedding of the previous label
                    (forced with the gold label during training, predicted otherwise) is
                    concatenated to the encoder hidden state at every step.
                force_ratio (float): teacher-forcing ratio used during training.
                bidirectional (bool): whether the LSTM is bidirectional.
                layer_num (int): number of stacked LSTM layers.
                mode (str): prediction mode, e.g. "slot" or "token-level-intent".
                ignore_index (int, optional): label id skipped when building forced slot inputs. Defaults to -100.
                dropout_rate (float): dropout rate applied to the decoder input and between stacked LSTM layers.
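
        Example:
            A minimal configuration sketch for slot filling (key names mirror the
            Args above; all values are placeholders)::

                decoder = AutoregressiveLSTMClassifier(
                    mode="slot",
                    input_dim=768,
                    hidden_dim=256,
                    use_slot=True,
                    slot_label_num=72,
                    use_intent=False,
                    use_crf=False,
                    embedding_dim=32,
                    dropout_rate=0.4,
                    force_ratio=0.9,
                    bidirectional=False,
                    layer_num=1,
                    ignore_index=-100,
                )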
        """
        super(AutoregressiveLSTMClassifier, self).__init__(**config)
        if config.get("use_slot") and config.get("use_crf"):
            self.CRF = CRF(num_tags=config["slot_label_num"], batch_first=True)
        self.input_dim = config["input_dim"]
        self.hidden_dim = config["hidden_dim"]
        if config.get("use_intent"):
            self.output_dim = config["intent_label_num"]
        if config.get("use_slot"):
            self.output_dim = config["slot_label_num"]
        self.dropout_rate = config["dropout_rate"]
        self.embedding_dim = config.get("embedding_dim")
        self.force_ratio = config.get("force_ratio")
        self.config = config
        self.ignore_index = config.get("ignore_index") if config.get("ignore_index") is not None else -100

        # When label embeddings are enabled, the embedding of the previous label is fed
        # back into the LSTM, so each input step is the encoder hidden state concatenated
        # with that embedding; init_tensor stands in for the label before the first token.
        if self.embedding_dim is not None:
            self.embedding_layer = nn.Embedding(self.output_dim, self.embedding_dim)
            self.init_tensor = nn.Parameter(
                torch.randn(1, self.embedding_dim),
                requires_grad=True
            )
            lstm_input_dim = self.input_dim + self.embedding_dim
        else:
            lstm_input_dim = self.input_dim

        self.dropout_layer = nn.Dropout(self.dropout_rate)
        self.lstm_layer = nn.LSTM(
            input_size=lstm_input_dim,
            hidden_size=self.hidden_dim,
            batch_first=True,
            bidirectional=self.config["bidirectional"],
            dropout=self.dropout_rate,
            num_layers=self.config["layer_num"]
        )
        # NOTE: the output projection expects LSTM outputs of size hidden_dim,
        # i.e. a unidirectional LSTM.
        self.linear_layer = nn.Linear(
            self.hidden_dim,
            self.output_dim
        )

    def forward(self, hidden: HiddenData, internal_interaction=None, **interaction_args):
        """ Forward process for the autoregressive decoder.

        :param hidden: encoded hidden states from the encoder.
        :param internal_interaction: optional interaction module applied to the LSTM outputs.
        :return: distribution over the prediction labels.
        """
        input_tensor = hidden.slot_hidden
        seq_lens = hidden.inputs.attention_mask.sum(-1).detach().cpu().tolist()
        output_tensor_list, sent_start_pos = [], 0
        input_tensor = pack_sequence(input_tensor, seq_lens)
        forced_input = None
        if self.training:
            # Teacher forcing: with probability force_ratio, feed gold labels instead of
            # the model's own predictions.
            if random.random() < self.force_ratio:
                if self.config["mode"] == "slot":
                    # Replace ignored positions (e.g. sub-word tokens) with the previous
                    # valid slot label.
                    forced_slot = pack_sequence(hidden.inputs.slot, seq_lens)
                    temp_slot = []
                    for index, x in enumerate(forced_slot):
                        if index == 0:
                            temp_slot.append(x.reshape(1))
                        elif x == self.ignore_index:
                            temp_slot.append(temp_slot[-1])
                        else:
                            temp_slot.append(x.reshape(1))
                    forced_input = torch.cat(temp_slot, 0)
                if self.config["mode"] == "token-level-intent":
                    forced_intent = hidden.inputs.intent.unsqueeze(1).repeat(1, hidden.inputs.slot.shape[1])
                    forced_input = pack_sequence(forced_intent, seq_lens)
        if self.embedding_dim is None or forced_input is not None:
            # Whole-sentence decoding: either no label embedding is used, or the forced
            # (gold) labels provide the previous-label embeddings for all steps at once.
            for sent_i in range(0, len(seq_lens)):
                sent_end_pos = sent_start_pos + seq_lens[sent_i]

                seg_hiddens = input_tensor[sent_start_pos: sent_end_pos, :]

                if self.embedding_dim is not None and forced_input is not None:
                    if seq_lens[sent_i] > 1:
                        seg_forced_input = forced_input[sent_start_pos: sent_end_pos]
                        seg_forced_tensor = self.embedding_layer(seg_forced_input)[:-1]
                        seg_prev_tensor = torch.cat([self.init_tensor, seg_forced_tensor], dim=0)
                    else:
                        seg_prev_tensor = self.init_tensor

                    combined_input = torch.cat([seg_hiddens, seg_prev_tensor], dim=1)
                else:
                    combined_input = seg_hiddens
                dropout_input = self.dropout_layer(combined_input)
                lstm_out, _ = self.lstm_layer(dropout_input.view(1, seq_lens[sent_i], -1))
                if internal_interaction is not None:
                    interaction_args["sent_id"] = sent_i
                    lstm_out = internal_interaction(torch.transpose(lstm_out, 0, 1), **interaction_args)[:, 0]
                linear_out = self.linear_layer(lstm_out.view(seq_lens[sent_i], -1))

                output_tensor_list.append(linear_out)
                sent_start_pos = sent_end_pos
        else:
            # Token-by-token decoding: feed the embedding of the previously predicted
            # label back into the LSTM at each step.
            for sent_i in range(0, len(seq_lens)):
                prev_tensor = self.init_tensor

                last_h, last_c = None, None

                sent_end_pos = sent_start_pos + seq_lens[sent_i]
                for word_i in range(sent_start_pos, sent_end_pos):
                    seg_input = input_tensor[[word_i], :]
                    combined_input = torch.cat([seg_input, prev_tensor], dim=1)
                    dropout_input = self.dropout_layer(combined_input).view(1, 1, -1)
                    if last_h is None and last_c is None:
                        lstm_out, (last_h, last_c) = self.lstm_layer(dropout_input)
                    else:
                        lstm_out, (last_h, last_c) = self.lstm_layer(dropout_input, (last_h, last_c))

                    if internal_interaction is not None:
                        interaction_args["sent_id"] = sent_i
                        lstm_out = internal_interaction(lstm_out, **interaction_args)[:, 0]

                    lstm_out = self.linear_layer(lstm_out.view(1, -1))
                    output_tensor_list.append(lstm_out)

                    _, index = lstm_out.topk(1, dim=1)
                    prev_tensor = self.embedding_layer(index).view(1, -1)
                sent_start_pos = sent_end_pos
        seq_unpacked = unpack_sequence(torch.cat(output_tensor_list, dim=0), seq_lens)

        if self.config.get("use_multi"):
            pred_output = ClassifierOutputData(seq_unpacked)
        else:
            pred_output = ClassifierOutputData(F.log_softmax(seq_unpacked, dim=-1))
        return pred_output


class MLPClassifier(BaseClassifier):
    """
    Decoder structure based on MLP.
    """
    def __init__(self, **config):
        """Construction function for MLPClassifier.

        Args:
            config (dict):
                use_slot (bool): whether to classify slot labels.
                use_intent (bool): whether to classify intent labels.
                mlp (List): list of layer configurations instantiated in order, for example:

                    - _model_target_: torch.nn.Linear

                      in_features (int): input feature dim. A string value such as
                      "{input_dim}" is resolved by stripping the brackets and looking
                      the name up in this classifier's config.

                      out_features (int): output feature dim, resolved the same way.

                    - _model_target_: torch.nn.LeakyReLU

                      negative_slope: 0.2

                    - ...
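
        Example:
            A minimal configuration sketch for intent classification (key names follow
            the Args above plus the ``mode``, ``input_dim``, ``intent_label_num`` and
            ``ignore_index`` keys used elsewhere; the numbers are placeholders)::

                classifier = MLPClassifier(
                    mode="intent",
                    use_intent=True,
                    use_slot=False,
                    ignore_index=-100,
                    input_dim=768,
                    intent_label_num=20,
                    mlp=[
                        {"_model_target_": "torch.nn.Linear",
                         "in_features": "{input_dim}",
                         "out_features": 256},
                        {"_model_target_": "torch.nn.LeakyReLU",
                         "negative_slope": 0.2},
                        {"_model_target_": "torch.nn.Linear",
                         "in_features": 256,
                         "out_features": "{intent_label_num}"},
                    ],
                )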
        """
        super(MLPClassifier, self).__init__(**config)
        self.config = config
        # Resolve string-valued feature dims (e.g. "{input_dim}") against the classifier config.
        for i, x in enumerate(config["mlp"]):
            if isinstance(x.get("in_features"), str):
                config["mlp"][i]["in_features"] = self.config[x["in_features"][1:-1]]
            if isinstance(x.get("out_features"), str):
                config["mlp"][i]["out_features"] = self.config[x["out_features"][1:-1]]
        mlp = [instantiate(x) for x in config["mlp"]]
        self.seq = nn.Sequential(*mlp)

    def forward(self, hidden: HiddenData):
        # Use the intent hidden state when intent classification is enabled; otherwise
        # fall back to the slot hidden state.
        if self.config.get("use_intent"):
            res = self.seq(hidden.intent_hidden)
        else:
            res = self.seq(hidden.slot_hidden)
        return ClassifierOutputData(res)