# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Implements tracking of constraints for a beam item. A list of constraints is given as a list of one or more token sequences, each of length at least one token. For example, for an input sentence > Die maschinelle Übersetzung ist schwer zu kontrollieren. We could have the constraints: * to influence * hard There are two implementations: * OrderedConstraintState: Tracks progress through an ordered list of multitoken constraints. * UnorderedConstraintState: Tracks progress through an unordered list of multitoken constraints. The difference is that in the first, the constraints are assumed to be in order; the algorithm will permit zero or more tokens between them. In the second, the constraints are not ordered, so many orderings will be explored. The same sequence can be present any number of times, and will appear that many times in the output. """ from collections import Counter from typing import List, Optional, Set, Tuple import torch class ConstraintState: def __init__(self): pass def pack_constraints(batch_constraints: List[List[torch.Tensor]]) -> torch.Tensor: """Takes a list of list of constraints in tensor form (a list of tensor constraints for each sentence) and transforms it into a packed Tensor. For example, here is a batch of size 3 with 3, 0, and 1 constraints: [ [ [3 1 2], [3], [4 5 6 7], ] [], [ [1 8 9 10 1 4 11 12], ] ] Its corresponding packed structure is: [ [ 3 3 1 2 0 3 0 4 5 6 7 0], [ 0 0 0 0 0 0 0 0 0 0 0 0], [ 1 1 8 9 10 1 4 11 12 0 0 0] ] The packed tensor has shape (batch size, maxlen), where maxlen is defined below. Each row contains concatenated constraint tokens for that sentence, with 0 appended after each constraint. The first item in each row is the number of constraints for that sentence. So maxlen is the maximum of (number of constraints) + (sum length of constraints) + 1. across all sentences in the batch. """ # The maximum word length of concatenated constraints for any sentence max_constraints_len = 1 for sentence_constraints in batch_constraints: if len(sentence_constraints): # number of constraints, plus sum of constrain lens, plus a zero after each constraints_len = ( 1 + sum([c.size(0) for c in sentence_constraints]) + len(sentence_constraints) ) max_constraints_len = max(max_constraints_len, constraints_len) batch_size = len(batch_constraints) constraints_tensor = torch.zeros((batch_size, max_constraints_len)).long() for i, sentence_constraints in enumerate(batch_constraints): constraints_tensor[i, 0] = len(sentence_constraints) offset = 1 for j, constraint in enumerate(sentence_constraints): this_len = constraint.size(0) constraints_tensor[i, offset : offset + this_len] = constraint offset += this_len + 1 return constraints_tensor.long() def unpack_constraints(constraint_tensor: torch.Tensor) -> List[torch.Tensor]: """ Transforms *one row* of a packed constraint tensor (e.g., for one sentence in the batch) into a list of constraint tensors. """ constraint_list = [] num_constraints = constraint_tensor[0] constraints = constraint_tensor.tolist() offset = 1 for i in range(num_constraints): where = constraints.index(0, offset) constraint_list.append(constraint_tensor[offset:where]) offset = where + 1 return constraint_list class ConstraintNode: """ Represents a node in a trie managing unordered constraints. """ def __init__(self, token: int = None, parent=None): # The token associate with this node (None for the root) self.token = int(token) if token is not None else None # The parent (None at the root) self.parent = parent # Whether this node is a completed constraint self.terminal = 0 # List of child nodes self.children = {} # The cumulative number of constraints from this point in the # trie forward self.num_constraints = 0 @property def id(self): return self.token def __str__(self): term = self.terminal != 0 return f"[{self.token}].{term}#{self.num_constraints}" def __getitem__(self, key: int): return self.children.get(key, None) def next_tokens(self) -> Set[int]: """The set of child labels.""" return set(self.children.keys()) @staticmethod def create(constraints: List[List[int]]): root = ConstraintNode() for sequence in constraints: root.add_sequence(sequence) return root @staticmethod def print_graph(node: "ConstraintNode"): if len(node.children) == 0: return str(node) else: s = f"({node}" for child in node.children.values(): s += " " + ConstraintNode.print_graph(child) s += ")" return s def token_counts(self) -> Counter: """Returns a counter of the number of times each token is used in a constraint. """ token_counts = Counter() kids = list(self.children.values()) while len(kids) > 0: kid = kids.pop() token_counts[kid.id] += kid.num_constraints kids += list(kid.children.values()) return token_counts def tokens(self) -> Set[int]: """Returns the set of tokens in constraints.""" return set(self.token_counts().keys()) def add_sequence(self, sequence: List[int]): """Adds a constraint, represented as a list of integers, to the trie.""" assert len(sequence) > 0 token = int(sequence[0]) if token not in self.children: self.children[token] = ConstraintNode(token, parent=self) node = self.children[token] if len(sequence) == 1: node.terminal += 1 node.num_constraints += 1 parent = node.parent while parent is not None: parent.num_constraints += 1 parent = parent.parent else: node.add_sequence(sequence[1:]) class UnorderedConstraintState(ConstraintState): """ Records progress through the set of constraints for each item in the beam using a trie. """ def __init__(self, node: ConstraintNode, copy_from: "ConstraintState" = None): self.node = node if copy_from is None: # The root node self.root = node # The set of states in the graph that have been completed self.completed = Counter() # The... self.generated = Counter() # The list of tokens we need to generate self.needed_tokens = self.root.tokens() else: self.completed = Counter(copy_from.completed) self.generated = Counter(copy_from.generated) self.root = copy_from.root # Mark the node as generated if self.node != self.root: self.generated[node] += 1 @staticmethod def create(constraint_tensor: torch.Tensor): constraint_list = unpack_constraints(constraint_tensor) constraint_trie_root = ConstraintNode.create(constraint_list) return UnorderedConstraintState(constraint_trie_root) def __str__(self): gen_str = ",".join([str(node) for node in self.generated]) return f"{self.name}/{self.bank}({gen_str})x{self.num_completed}" def __copy__(self): copied_state = UnorderedConstraintState(self.node, copy_from=self) return copied_state def copy(self): return self.__copy__() @property def name(self): if self.node.id is None: return "ROOT" else: return str(self.node.id) @property def is_root(self): return self.node == self.root @property def bank(self): return sum(self.generated.values()) @property def num_completed(self): """The number of constraints (not constraint tokens) that are completed. In addition to the already-completed states, we need to account for the current state, which might get marked as completed when another token is generated. """ in_final = self.node.terminal and self.completed[self.node] < self.node.terminal return sum(self.completed.values()) + in_final @property def finished(self): return self.root.num_constraints - self.num_completed == 0 @property def token_counts(self): return self.root.token_counts() @property def tokens(self): return self.root.tokens() @property def num_constraint_tokens(self): return sum(self.token_counts.values()) def next_tokens(self) -> Set[int]: """Returns the list of tokens that could come next. These are (a) all tokens extending the root state and, for non-root states, additionally all tokens extending the current state.""" if self.node != self.root: return self.root.next_tokens().union(self.node.next_tokens()) else: return self.root.next_tokens() def advance(self, token: int): """Reads in a token and advances the state. Here's how it works. We can advance to the next state if: - there is a matching child - its path isn't blocked A path is blocked when all constraints that are descendants of that node have already been generated, in the current state. If we are not able to advance from the current state, we "fall off the graph" and return to the root state. There, we again try to advance, checking the same criteria. In any case, when falling off the graph, we need to do some bookkeeping. We: - check whether any constraints were met (all prefixes of current state) - if one is found, mark it as completed - adjust visited nodes accordingly """ token = int(token) next_state = None child = self.node[token] if child is not None and self.generated[child] < child.num_constraints: next_state = UnorderedConstraintState(child, copy_from=self) def rewind(): """If we're mid-trie and an "illegal" token is chosen next, we need to reset our state to the root state. However, along the way, we need to check whether a prefix of the current trie state represents a state we could mark as completed. """ node = self.node while node != self.root: if node.terminal and self.completed[node] < node.terminal: next_state.completed[node] += 1 return next_state.generated[node] -= 1 node = node.parent # Fall off the graph, check the root if next_state is None and token in self.root.next_tokens(): child = self.root[token] # We can only traverse this edge if it's not saturated if self.generated[child] < child.num_constraints: next_state = UnorderedConstraintState(child, copy_from=self) else: next_state = UnorderedConstraintState(self.root, copy_from=self) # Rewind rewind() elif next_state is None: next_state = UnorderedConstraintState(self.root, copy_from=self) # Rewind rewind() return next_state class ConstraintSequence: def __init__(self, sequences: List[List[int]]): """Represents a set of possibly multitoken constraints by concatenating them and internally recording the end points. """ self.sequences = [] self.endpoints = [] self.num_tokens = 0 self.tokens = set() for sequence in sequences: for token in sequence: self.tokens.add(token) self.num_tokens += len(sequence) self.endpoints += [False for x in range(len(sequence) - 1)] + [True] self.sequences += sequence def __getitem__(self, key: int): return self.sequences[key] def __len__(self): return len(self.sequences) def __str__(self): return str(self.sequences) class OrderedConstraintState(ConstraintState): """ Records progress through the set of linear nonbranching constraints with gaps. """ def __init__(self, sequence: ConstraintSequence, state: int = -1): self.sequence = sequence self.state = state @staticmethod def create(constraint_tensor: torch.Tensor): constraint_list = unpack_constraints(constraint_tensor) return OrderedConstraintState(ConstraintSequence(constraint_list), -1) def __str__(self): return f"{self.state}/{self.bank}x{self.num_completed}" def __copy__(self): return OrderedConstraintState(self.sequence, self.state) def copy(self): return self.__copy__() @property def num_completed(self): if self.state == -1: return 0 count = len( list(filter(lambda x: x, self.sequence.endpoints[0 : self.state + 1])) ) return count @property def is_root(self): return self.state == -1 @property def name(self): if self.state == -1: return "ROOT" else: return str(self.sequence[self.state]) @property def bank(self) -> int: return self.state + 1 @property def finished(self): return self.state + 1 == len(self.sequence) @property def token_counts(self): return self.sequence.token_counts() @property def tokens(self): return self.sequence.tokens @property def num_constraint_tokens(self): return sum(self.token_counts.values()) def next_tokens(self) -> Set[int]: """Returns the list of tokens that could come next. These are (a) all tokens extending the root state and, for non-root states, additionally all tokens extending the current state.""" tokens = set() if self.state > 0: tokens.add(self.sequence[0]) if not self.finished: tokens.add(self.sequence[self.state + 1]) return tokens def advance(self, token: int): """Reads in a token and advances the state. Here's how it works. We can advance to the next state if: - there is a matching child - its path isn't blocked A path is blocked when all constraints that are descendants of that node have already been generated, in the current state. If we are not able to advance from the current state, we "fall off the graph" and return to the root state. There, we again try to advance, checking the same criteria. In any case, when falling off the graph, we need to do some bookkeeping. We: - check whether any constraints were met (all prefixes of current state) - if one is found, mark it as completed - adjust visited nodes accordingly """ token = int(token) # print(f"{self} ADVANCE({token}) {self.sequence} -> ", end="") if self.finished: # Accept anything next_state = self.copy() elif self.sequence[self.state + 1] == token: # Advance to the next token next_state = OrderedConstraintState(self.sequence, self.state + 1) elif self.sequence.endpoints[self.state]: # Accept anything between constraints (*) next_state = self.copy() elif token == self.sequence[0]: # Start over having generated the first token next_state = OrderedConstraintState(self.sequence, 0) else: # Start over from the root next_state = OrderedConstraintState(self.sequence, -1) return next_state