import re
from typing import List, Dict, Set

import numpy as np
import torch
from ufal.chu_liu_edmonds import chu_liu_edmonds

# Universal Dependencies relation labels that the label parser can predict.
DEPENDENCY_RELATIONS = [
    "acl", "advcl", "advmod", "amod", "appos", "aux", "case", "cc", "ccomp",
    "conj", "cop", "csubj", "det", "iobj", "mark", "nmod", "nsubj", "nummod",
    "obj", "obl", "parataxis", "punct", "root", "vocative", "xcomp",
]

INDEX2TAG = {idx: tag for idx, tag in enumerate(DEPENDENCY_RELATIONS)}
TAG2INDEX = {tag: idx for idx, tag in enumerate(DEPENDENCY_RELATIONS)}


def preprocess_text(text: str) -> List[str]:
    """Split raw text into whitespace-separated words for the tokenizer."""
    text = text.strip()
    # NOTE: the original lookbehind pattern was lost in extraction; this
    # punctuation-detaching pattern is an assumption with the same shape.
    text = re.sub(r"(?<=\w)([.,:;!?])", r" \1", text)
    return text.split()


# NOTE: the original def line was garbled in extraction; the signature is
# reconstructed from the intact body, which is the standard batched gather helper.
def batched_index_select(
    input: torch.Tensor, dim: int, index: torch.Tensor
) -> torch.Tensor:
    """Gather entries along ``dim`` of a batched tensor, one index set per batch."""
    views = [input.shape[0]] + [
        1 if i != dim else -1 for i in range(1, len(input.shape))
    ]
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    return torch.gather(input, dim, index)


def get_relevant_tokens(tokenized: torch.Tensor, start_ids: Set[int]) -> List[int]:
    """Keep only the entries at word-start token positions (one per word)."""
    return [tokenized[idx].item() for idx in range(len(tokenized)) if idx in start_ids]


def resolve(
    edmonds_head: List[int], word_ids: List[int], parent_probs_table: torch.Tensor
) -> torch.Tensor:
    """If the decoded tree has several roots, keep the one with the most
    dependents and block the others from attaching to ROOT before re-decoding."""
    multiple_roots = [i for i, x in enumerate(edmonds_head) if x == 0]
    if len(multiple_roots) > 1:
        # Head values in edmonds_head are token positions, so a word's dependents
        # point at its first subword token; convert the candidate word index to
        # its start-token index before counting dependents.
        main_root = max(
            multiple_roots, key=lambda word: edmonds_head.count(word_ids.index(word))
        )
        secondary_roots = set(multiple_roots) - {main_root}
        for root in secondary_roots:
            # Zero out the ROOT-attachment probability of each secondary root.
            parent_probs_table[0][word_ids.index(root)][0] = 0
    return parent_probs_table


def apply_chu_liu_edmonds(
    parent_probs_table: torch.Tensor, tokenized_input: Dict, start_ids: Set[int]
) -> List[int]:
    """Decode the highest-scoring dependency tree with Chu-Liu/Edmonds."""
    # Ensure the score table is square before decoding.
    parent_probs_table = (
        parent_probs_table
        if parent_probs_table.shape[1] == parent_probs_table.shape[2]
        else parent_probs_table[:, :, 1:]
    )
    edmonds_heads, _ = chu_liu_edmonds(
        parent_probs_table.squeeze(0).cpu().numpy().astype("double")
    )
    edmonds_heads = torch.tensor(edmonds_heads).unsqueeze(0)
    # Unattached nodes (head -1) are re-attached to ROOT (position 0).
    edmonds_heads[edmonds_heads == -1] = 0
    tokenized_input["head_labels"] = edmonds_heads
    return get_relevant_tokens(edmonds_heads[0], start_ids)


def get_word_endings(tokenized_input):
    """Map each token position to its word's (end token, 1-based word number)."""
    word_ids = tokenized_input.word_ids(batch_index=0)
    start_ids = set()
    word_endings = {0: (1, 0)}  # token 0 is the artificial ROOT
    for word_id in word_ids:
        if word_id is not None:
            start, end = tokenized_input.word_to_tokens(
                batch_or_word_index=0, word_index=word_id
            )
            start_ids.add(start)
            word_endings[start] = (end, word_id + 1)
            for a in range(start + 1, end + 1):
                word_endings[a] = (end, word_id + 1)
    return word_endings, start_ids, word_ids


def get_dependencies(
    dependency_parser,
    label_parser,
    tokenizer,
    collator,
    labels: bool,
    sentence: List[str],
) -> Dict:
    """Parse a pre-split sentence into a displaCy-style dict of words and arcs."""
    tokenized_input = tokenizer(
        sentence, truncation=True, is_split_into_words=True, add_special_tokens=True
    )
    dep_dict: Dict[str, List[Dict[str, str]]] = {
        "words": [{"text": "ROOT", "tag": ""}],
        "arcs": [],
    }
    word_endings, start_ids, word_ids = get_word_endings(tokenized_input)
    tokenized_input = collator([tokenized_input])
    _, _, parent_probs_table = dependency_parser(**tokenized_input)
    # Mask every position that neither starts a word nor is ROOT, so the
    # decoder only connects words, never subword continuations.
    irrelevant = torch.tensor(
        [
            idx.item()
            for idx in torch.arange(parent_probs_table.size(1))
            if idx.item() not in start_ids and idx.item() != 0
        ]
    )
    if irrelevant.nelement() > 0:
        parent_probs_table.index_fill_(1, irrelevant, torch.nan)
        parent_probs_table.index_fill_(2, irrelevant, torch.nan)
    # Decode once, demote any secondary roots, then decode again.
    edmonds_head = apply_chu_liu_edmonds(parent_probs_table, tokenized_input, start_ids)
    parent_probs_table = resolve(edmonds_head, word_ids, parent_probs_table)
    edmonds_head = apply_chu_liu_edmonds(parent_probs_table, tokenized_input, start_ids)
    if labels:
        predictions_labels = np.argmax(
            label_parser(**tokenized_input).logits.detach().cpu().numpy(), axis=-1
        )
        predicted_relations = get_relevant_tokens(predictions_labels[0], start_ids)
        predicted_relations = [
            INDEX2TAG[predicted_relations[idx]] for idx in range(len(sentence))
        ]
    else:
        predicted_relations = [""] * len(sentence)
    for idx, head in enumerate(edmonds_head):
        # Words are 1-based in the output; word_endings maps a head's token
        # position back to its 1-based word number.
        arc = {
            "start": min(idx + 1, word_endings[head][1]),
            "end": max(idx + 1, word_endings[head][1]),
            "label": predicted_relations[idx],
            "dir": "left" if idx + 1 < word_endings[head][1] else "right",
        }
        dep_dict["arcs"].append(arc)
    return dep_dict
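

# Usage sketch (assumptions, not defined in this module): a trained
# ``dependency_parser`` whose forward pass returns a 3-tuple ending in the
# parent-probability table, a ``label_parser`` token classifier over
# DEPENDENCY_RELATIONS, and a matching Hugging Face tokenizer plus padding
# collator. The checkpoint name below is hypothetical.
#
#     from transformers import AutoTokenizer, DataCollatorWithPadding
#
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # hypothetical
#     collator = DataCollatorWithPadding(tokenizer, return_tensors="pt")
#     words = preprocess_text("The cat sat on the mat.")
#     dep_dict = get_dependencies(
#         dependency_parser, label_parser, tokenizer, collator, True, words
#     )
#     print(dep_dict["arcs"])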