#!/usr/bin/env python3
# coding=utf-8

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from model.module.edge_classifier import EdgeClassifier
from model.module.anchor_classifier import AnchorClassifier
from utility.cross_entropy import cross_entropy, binary_cross_entropy
from utility.hungarian_matching import get_matching, reorder, match_anchor, match_label
from utility.utils import create_padding_mask


class AbstractHead(nn.Module):
    """Prediction head combining label, anchor and edge classifiers.

    Each classifier is optional: the ``init_*`` factories return ``None`` when
    the corresponding ``config`` entry is falsy, and the matching ``forward_*``
    / ``loss_*`` methods then yield ``None`` / an empty dict.

    NOTE(review): ``predict`` calls ``self.parser.parse``, but ``parser`` is not
    assigned anywhere in this class -- presumably subclasses set it; confirm.
    """

    def __init__(self, dataset, args, config, initialize: bool):
        """Build the sub-classifiers selected by ``config``.

        Args:
            dataset: Passed to the classifier factories and stored as
                ``self.dataset``; ``label_field.vocab`` and ``label_freqs`` are
                read from it when the label classifier is created.
            args: Hyper-parameters; ``query_length``, ``focal``,
                ``hidden_size`` and ``dropout_label`` are read in this class.
            config: Mapping with the keys ``"label"``, ``"anchor"``,
                ``"source_anchor"``, ``"target_anchor"``, ``"edge presence"``
                and ``"edge label"`` that switch the classifiers on or off.
            initialize: Forwarded to the classifier factories.
        """
        super().__init__()

        self.edge_classifier = self.init_edge_classifier(dataset, args, config, initialize)
        self.label_classifier = self.init_label_classifier(dataset, args, config, initialize)
        self.anchor_classifier = self.init_anchor_classifier(dataset, args, config, initialize, mode="anchor")
        self.source_anchor_classifier = self.init_anchor_classifier(dataset, args, config, initialize, mode="source_anchor")
        self.target_anchor_classifier = self.init_anchor_classifier(dataset, args, config, initialize, mode="target_anchor")

        self.query_length = args.query_length
        self.focal = args.focal
        self.dataset = dataset

    def forward(self, encoder_output, decoder_output, encoder_mask, decoder_mask, batch):
        """Run all classifiers, match queries to target nodes and return the loss.

        Returns:
            The ``(total_loss, stats)`` tuple produced by :meth:`loss`.
        """
        output = {}

        # Every input position owns ``query_length`` decoder queries.
        decoder_lens = self.query_length * batch["every_input"][1]
        output["label"] = self.forward_label(decoder_output)
        output["anchor"] = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="anchor")  # shape: (B, T_l, T_w)
        output["source_anchor"] = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="source_anchor")  # shape: (B, T_l, T_w)
        output["target_anchor"] = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="target_anchor")  # shape: (B, T_l, T_w)

        cost_matrices = self.create_cost_matrices(output, batch, decoder_lens)
        matching = get_matching(cost_matrices)

        # The edge classifier runs on the decoder states after they have been
        # reordered according to the matching.
        decoder_output = reorder(decoder_output, matching, batch["labels"][0].size(1))
        output["edge presence"], output["edge label"] = self.forward_edge(decoder_output)

        return self.loss(output, batch, matching, decoder_mask)

    def predict(self, encoder_output, decoder_output, encoder_mask, decoder_mask, batch, **kwargs):
        """Decode one batch and return one parsed output per example.

        Queries whose inferred label index is 0 are dropped; the remaining
        decoder states are compacted to the front of ``decoder_output``
        (written in place, so the caller's tensor is modified) before the edge
        classifier is applied.

        Args:
            **kwargs: Forwarded unchanged to ``self.parser.parse``.

        Returns:
            list: ``self.parser.parse(...)`` results, one per batch element.
        """
        every_input, word_lens = batch["every_input"]
        decoder_lens = self.query_length * word_lens
        batch_size = every_input.size(0)

        label_pred = self.forward_label(decoder_output)
        anchor_pred = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="anchor")  # shape: (B, T_l, T_w)
        source_anchor_pred = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="source_anchor")  # shape: (B, T_l, T_w)
        target_anchor_pred = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="target_anchor")  # shape: (B, T_l, T_w)

        labels = [[] for _ in range(batch_size)]
        anchors = [[] for _ in range(batch_size)]
        source_anchors = [[] for _ in range(batch_size)]
        target_anchors = [[] for _ in range(batch_size)]

        for b in range(batch_size):
            word_len = word_lens[b]
            label_indices = self.inference_label(label_pred[b, :decoder_lens[b], :]).cpu()
            for t in range(label_indices.size(0)):
                label_index = label_indices[t].item()
                if label_index == 0:
                    continue

                # Compact kept queries to the front. The write index never
                # exceeds ``t``, so states that are still to be read are intact.
                decoder_output[b, len(labels[b]), :] = decoder_output[b, t, :]

                labels[b].append(label_index)
                anchors[b].append(self._infer_anchor(anchor_pred, b, t, word_len))
                source_anchors[b].append(self._infer_anchor(source_anchor_pred, b, t, word_len))
                target_anchors[b].append(self._infer_anchor(target_anchor_pred, b, t, word_len))

        # ``default=0`` keeps an empty batch from raising inside ``max``.
        max_nodes = max((len(example_labels) for example_labels in labels), default=0)
        decoder_output = decoder_output[:, :max_nodes, :]
        edge_presence, edge_labels = self.forward_edge(decoder_output)

        outputs = [
            self.parser.parse(
                {
                    "labels": labels[b],
                    "anchors": anchors[b],
                    "source anchors": source_anchors[b],
                    "target anchors": target_anchors[b],
                    "edge presence": self.inference_edge_presence(edge_presence, b),
                    "edge labels": self.inference_edge_label(edge_labels, b),
                    "id": batch["id"][b].cpu(),
                    "tokens": every_input[b, :word_lens[b]].cpu(),
                    "token intervals": batch["token_intervals"][b, :, :].cpu(),
                },
                **kwargs
            )
            for b in range(batch_size)
        ]

        return outputs

    def _infer_anchor(self, prediction, b: int, t: int, word_len):
        """Return the anchor output of query ``t`` in example ``b``.

        Without a classifier output (``prediction is None``) the fallback is the
        list of positions from ``t // query_length`` up to ``word_len``;
        otherwise the sigmoid of the logits over the first ``word_len`` inputs,
        moved to the CPU.
        """
        if prediction is None:
            return list(range(t // self.query_length, word_len))
        return self.inference_anchor(prediction[b, t, :word_len]).cpu()

    def loss(self, output, batch, matching, decoder_mask):
        """Combine the individual loss terms into their unweighted mean.

        Side effect: when an edge-label prediction exists, ``batch["edge_labels"]``
        is replaced by a copy whose last dimension is truncated to the
        prediction's last dimension.

        Returns:
            tuple: ``(total_loss, stats)`` where ``stats`` maps each loss name
            to its Python scalar value.

        Raises:
            ZeroDivisionError: If no loss term is produced (every classifier
                disabled), because the mean is taken over zero terms.
        """
        batch_size = batch["every_input"][0].size(0)
        device = batch["every_input"][0].device
        T_label = batch["labels"][0].size(1)
        T_input = batch["every_input"][0].size(1)
        T_edge = batch["edge_presence"].size(1)

        # NOTE(review): presumably ``True`` marks positions excluded from the
        # loss -- confirm against create_padding_mask / binary_cross_entropy.
        input_mask = create_padding_mask(batch_size, T_input, batch["every_input"][1], device)  # shape: (B, T_input)
        label_mask = create_padding_mask(batch_size, T_label, batch["labels"][1], device)  # shape: (B, T_label)

        # The diagonal (self-loops) is always masked, together with every row
        # and column flagged by ``label_mask``.
        edge_mask = torch.eye(T_label, T_label, device=device, dtype=torch.bool).unsqueeze(0)  # shape: (1, T_label, T_label)
        edge_mask = edge_mask | label_mask.unsqueeze(1) | label_mask.unsqueeze(2)  # shape: (B, T_label, T_label)
        if T_edge != T_label:
            # Pad with False at the front of the last two dimensions so the
            # mask matches the edge targets.
            edge_mask = F.pad(edge_mask, (T_edge - T_label, 0, T_edge - T_label, 0), value=0)
        edge_label_mask = (batch["edge_presence"] == 0) | edge_mask

        if output["edge label"] is not None:
            batch["edge_labels"] = (
                batch["edge_labels"][0][:, :, :, :output["edge label"].size(-1)],
                batch["edge_labels"][1],
            )

        losses = {}
        losses.update(self.loss_label(output, batch, decoder_mask, matching))
        losses.update(self.loss_anchor(output, batch, input_mask, matching, mode="anchor"))
        losses.update(self.loss_anchor(output, batch, input_mask, matching, mode="source_anchor"))
        losses.update(self.loss_anchor(output, batch, input_mask, matching, mode="target_anchor"))
        losses.update(self.loss_edge_presence(output, batch, edge_mask))
        losses.update(self.loss_edge_label(output, batch, edge_label_mask.unsqueeze(-1)))

        stats = {key: value.detach().cpu().item() for key, value in losses.items()}
        total_loss = sum(losses.values()) / len(losses)

        return total_loss, stats

    @torch.no_grad()
    def create_cost_matrices(self, output, batch, decoder_lens):
        """Build one query-by-node score matrix per example for the matching.

        Each matrix is the elementwise product of the label and anchor
        matrices. Only ``output["anchor"]`` takes part; the source/target
        anchor predictions are not used here.

        NOTE(review): if both the label and the anchor output are ``None`` the
        product is the float ``1.0`` and ``.cpu()`` fails -- confirm that this
        configuration is never used.

        Returns:
            list: CPU tensors, one per batch element.
        """
        batch_size = len(batch["labels"][1])
        decoder_lens = decoder_lens.cpu()

        matrices = []
        for b in range(batch_size):
            label_cost_matrix = self.label_cost_matrix(output, batch, decoder_lens, b)
            anchor_cost_matrix = self.anchor_cost_matrix(output, batch, decoder_lens, b)

            cost_matrix = label_cost_matrix * anchor_cost_matrix
            matrices.append(cost_matrix.cpu())

        return matrices

    def init_edge_classifier(self, dataset, args, config, initialize: bool):
        """Return an ``EdgeClassifier``, or ``None`` if both edge outputs are disabled."""
        if not config["edge presence"] and not config["edge label"]:
            return None
        return EdgeClassifier(dataset, args, initialize, presence=config["edge presence"], label=config["edge label"])

    def init_label_classifier(self, dataset, args, config, initialize: bool):
        """Return a dropout + linear label classifier, or ``None`` if disabled.

        The output size is the label vocabulary size plus one. With
        ``initialize`` set, the bias starts at ``log(dataset.label_freqs)``.
        """
        if not config["label"]:
            return None

        classifier = nn.Sequential(
            nn.Dropout(args.dropout_label),
            nn.Linear(args.hidden_size, len(dataset.label_field.vocab) + 1, bias=True)
        )
        if initialize:
            classifier[1].bias.data = dataset.label_freqs.log()

        return classifier

    def init_anchor_classifier(self, dataset, args, config, initialize: bool, mode="anchor"):
        """Return an ``AnchorClassifier`` for ``mode``, or ``None`` if ``config[mode]`` is falsy."""
        if not config[mode]:
            return None

        return AnchorClassifier(dataset, args, initialize, mode=mode)

    def forward_edge(self, decoder_output):
        """Return the edge classifier's output, or ``(None, None)`` without a classifier."""
        if self.edge_classifier is None:
            return None, None
        return self.edge_classifier(decoder_output)

    def forward_label(self, decoder_output):
        """Return log-probabilities over labels (last dim), or ``None`` without a classifier."""
        if self.label_classifier is None:
            return None
        return torch.log_softmax(self.label_classifier(decoder_output), dim=-1)

    def forward_anchor(self, decoder_output, encoder_output, encoder_mask, mode="anchor"):
        """Return the output of the ``{mode}_classifier`` attribute, or ``None`` if it is unset."""
        classifier = getattr(self, f"{mode}_classifier")
        if classifier is None:
            return None
        return classifier(decoder_output, encoder_output, encoder_mask)

    def inference_label(self, prediction):
        """Pick a label index per query from 2-D log-probabilities.

        Index 0 is chosen when its probability exceeds the summed probability
        of all other classes; otherwise the most probable class among indices
        ``1..`` is returned.
        """
        prediction = prediction.exp()
        return torch.where(
            prediction[:, 0] > prediction[:, 1:].sum(-1),
            torch.zeros(prediction.size(0), dtype=torch.long, device=prediction.device),
            prediction[:, 1:].argmax(dim=-1) + 1
        )

    def inference_anchor(self, prediction):
        """Turn anchor logits into probabilities with a sigmoid."""
        return prediction.sigmoid()

    def inference_edge_presence(self, prediction, example_index: int):
        """Return edge-presence probabilities of one example on the CPU, diagonal zeroed."""
        if prediction is None:
            return None

        N = prediction.size(1)
        mask = torch.eye(N, N, device=prediction.device, dtype=torch.bool)
        return prediction[example_index, :, :].sigmoid().masked_fill(mask, 0.0).cpu()

    def inference_edge_label(self, prediction, example_index: int):
        """Return the raw edge-label prediction of one example on the CPU."""
        if prediction is None:
            return None
        return prediction[example_index, :, :, :].cpu()

    def loss_edge_presence(self, prediction, target, mask):
        """Return ``{"edge presence": loss}``, or ``{}`` when there is nothing to score."""
        if self.edge_classifier is None or prediction["edge presence"] is None:
            return {}
        return {"edge presence": binary_cross_entropy(prediction["edge presence"], target["edge_presence"].float(), mask)}

    def loss_edge_label(self, prediction, target, mask):
        """Return ``{"edge label": loss}``, or ``{}`` when there is nothing to score."""
        if self.edge_classifier is None or prediction["edge label"] is None:
            return {}
        return {"edge label": binary_cross_entropy(prediction["edge label"], target["edge_labels"][0].float(), mask)}

    def loss_label(self, prediction, target, mask, matching):
        """Return ``{"label": loss}`` against the matched targets, or ``{}`` if disabled."""
        if self.label_classifier is None or prediction["label"] is None:
            return {}

        prediction = prediction["label"]
        target = match_label(
            target["labels"][0], matching, prediction.shape[:-1], prediction.device, self.query_length
        )
        return {"label": cross_entropy(prediction, target, mask, focal=self.focal)}

    def loss_anchor(self, prediction, target, mask, matching, mode="anchor"):
        """Return ``{mode: loss}`` for one anchor type, or ``{}`` if it is disabled.

        The mask from ``match_anchor`` (per query) is broadcast against the
        incoming ``mask`` (per input position) and the two are OR-ed.
        """
        if getattr(self, f"{mode}_classifier") is None or prediction[mode] is None:
            return {}

        prediction = prediction[mode]
        target, anchor_mask = match_anchor(target[mode], matching, prediction.shape, prediction.device)
        mask = anchor_mask.unsqueeze(-1) | mask.unsqueeze(-2)
        return {mode: binary_cross_entropy(prediction, target.float(), mask)}

    def label_cost_matrix(self, output, batch, decoder_lens, b: int):
        """Return the label part of the matching scores for example ``b``.

        The score is the square root of (probability mass on the target labels)
        times (probability mass on all classes except index 0). Returns ``1.0``
        when there is no label prediction.
        """
        if output["label"] is None:
            return 1.0

        target_labels = batch["anchored_labels"][b]  # shape: (num_nodes, num_inputs, num_classes)
        label_prob = output["label"][b, : decoder_lens[b], :].exp().unsqueeze(0)  # shape: (1, num_queries, num_classes)
        tgt_label = target_labels.repeat_interleave(self.query_length, dim=1)  # shape: (num_nodes, num_queries, num_classes)
        cost_matrix = ((tgt_label * label_prob).sum(-1) * label_prob[:, :, 1:].sum(-1)).t().sqrt()  # shape: (num_queries, num_nodes)
        return cost_matrix

    def anchor_cost_matrix(self, output, batch, decoder_lens, b: int):
        """Return the anchor part of the matching scores for example ``b``.

        For each (query, node) pair this is the geometric mean, over the input
        positions, of the probability the query assigns to the node's target
        anchor value. Returns ``1.0`` when there is no anchor prediction.
        """
        if output["anchor"] is None:
            return 1.0

        num_nodes = batch["labels"][1][b]
        word_len = batch["every_input"][1][b]
        target_anchors, _ = batch["anchor"]

        tgt_align = target_anchors[b, :num_nodes, :word_len]  # shape: (num_nodes, num_inputs)
        # Slice before the sigmoid so only this example's logits are converted
        # rather than the whole batch on every call.
        align_prob = output["anchor"][b, :decoder_lens[b], :word_len].sigmoid()  # shape: (num_queries, num_inputs)
        align_prob = align_prob.unsqueeze(1).expand(-1, num_nodes, -1)  # shape: (num_queries, num_nodes, num_inputs)
        align_prob = torch.where(tgt_align.unsqueeze(0).bool(), align_prob, 1.0 - align_prob)  # shape: (num_queries, num_nodes, num_inputs)
        cost_matrix = align_prob.log().mean(-1).exp()  # shape: (num_queries, num_nodes)
        return cost_matrix