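"""Utilities for decoding dependency trees from a transformer parser's outputs.

Covers text preprocessing, subword-to-word bookkeeping, and maximum spanning
tree decoding with the Chu-Liu/Edmonds algorithm.
"""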
import re
from typing import List, Dict, Set
import numpy as np
import torch
from ufal.chu_liu_edmonds import chu_liu_edmonds
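
# Dependency relation labels (a subset of the Universal Dependencies inventory).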
DEPENDENCY_RELATIONS = [
    "acl",
    "advcl",
    "advmod",
    "amod",
    "appos",
    "aux",
    "case",
    "cc",
    "ccomp",
    "conj",
    "cop",
    "csubj",
    "det",
    "iobj",
    "mark",
    "nmod",
    "nsubj",
    "nummod",
    "obj",
    "obl",
    "parataxis",
    "punct",
    "root",
    "vocative",
    "xcomp",
]
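
# Bidirectional mappings between label indices and relation names.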
INDEX2TAG = {idx: tag for idx, tag in enumerate(DEPENDENCY_RELATIONS)}
TAG2INDEX = {tag: idx for idx, tag in enumerate(DEPENDENCY_RELATIONS)}


def preprocess_text(text: str) -> List[str]:
    """Tokenize raw text by whitespace, detaching punctuation from words."""
    text = text.strip()
    # Pad punctuation (including the Greek ano teleia "·") with a space on
    # whichever side does not already have one, then split on whitespace.
    text = re.sub("(?<! )(?=[.,!?()·;:])|(?<=[.,!?()·;:])(?! )", r" ", text)
    return text.split()
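# Illustrative example: preprocess_text("Hello, world!") returns
# ["Hello", ",", "world", "!"].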


def batched_index_select(
    input: torch.Tensor, dim: int, index: torch.Tensor
) -> torch.Tensor:
    """Gather entries along `dim` independently for each batch element.

    `index` has shape (batch, k); the result keeps `input`'s shape except that
    dimension `dim` is replaced by k.
    """
    # Reshape the index so it can be broadcast over the remaining dimensions.
    views = [input.shape[0]] + [
        1 if i != dim else -1 for i in range(1, len(input.shape))
    ]
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    return torch.gather(input, dim, index)
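# Illustrative example: for input of shape (2, 4, 3), dim=1, and index of
# shape (2, 2), batched_index_select returns a tensor of shape (2, 2, 3).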


def get_relevant_tokens(tokenized: torch.Tensor, start_ids: Set[int]) -> List[int]:
    """Keep only the values at word-start token positions, in token order."""
    return [tokenized[idx].item() for idx in range(len(tokenized)) if idx in start_ids]


def resolve(
    edmonds_head: List[int], word_ids: List[int], parent_probs_table: torch.Tensor
) -> torch.Tensor:
    """If the decoded tree has several roots, keep one and demote the rest.

    The candidate whose index occurs most often as a head is kept; every other
    root has its probability of attaching to ROOT (column 0) zeroed out, so a
    second Chu-Liu/Edmonds pass must choose a different head for it.
    """
    multiple_roots = [i for i, x in enumerate(edmonds_head) if x == 0]
    if len(multiple_roots) > 1:
        main_root = max(multiple_roots, key=edmonds_head.count)
        secondary_roots = set(multiple_roots) - {main_root}
        for root in secondary_roots:
            # Map the word index back to its first subword token position and
            # forbid ROOT as its head.
            parent_probs_table[0][word_ids.index(root)][0] = 0
    return parent_probs_table


def apply_chu_liu_edmonds(
    parent_probs_table: torch.Tensor, tokenized_input: Dict, start_ids: Set[int]
) -> List[int]:
    """Decode a maximum spanning tree over head probabilities with Chu-Liu/Edmonds."""
    # Drop the extra first head column if the score matrix is not square.
    parent_probs_table = (
        parent_probs_table
        if parent_probs_table.shape[1] == parent_probs_table.shape[2]
        else parent_probs_table[:, :, 1:]
    )
    edmonds_heads, _ = chu_liu_edmonds(
        parent_probs_table.squeeze(0).cpu().numpy().astype("double")
    )
    edmonds_heads = torch.tensor(edmonds_heads).unsqueeze(0)
    # The library marks the artificial root's head as -1; remap it to 0 (ROOT).
    edmonds_heads[edmonds_heads == -1] = 0
    tokenized_input["head_labels"] = edmonds_heads
    return get_relevant_tokens(edmonds_heads[0], start_ids)


def get_word_endings(tokenized_input):
    """Map each token position to (last token of its word, 1-based word index).

    Also collects the set of word-start token positions. Position 0 (the
    special token before the sentence) doubles as the ROOT node.
    """
    word_ids = tokenized_input.word_ids(batch_index=0)
    start_ids = set()
    word_endings = {0: (1, 0)}
    for word_id in word_ids:
        if word_id is not None:
            start, end = tokenized_input.word_to_tokens(
                batch_or_word_index=0, word_index=word_id
            )
            start_ids.add(start)
            word_endings[start] = (end, word_id + 1)
            # All subsequent subwords of this word share the same mapping.
            for a in range(start + 1, end + 1):
                word_endings[a] = (end, word_id + 1)
    return word_endings, start_ids, word_ids


def get_dependencies(
    dependency_parser,
    label_parser,
    tokenizer,
    collator,
    labels: bool,
    sentence: List[str],
) -> Dict:
    """Parse a pre-split sentence into a displaCy-style dict of words and arcs."""
    tokenized_input = tokenizer(
        sentence, truncation=True, is_split_into_words=True, add_special_tokens=True
    )
    dep_dict: Dict[str, List[Dict[str, str]]] = {
        "words": [{"text": "ROOT", "tag": ""}],
        "arcs": [],
    }
    word_endings, start_ids, word_ids = get_word_endings(tokenized_input)
    tokenized_input = collator([tokenized_input])
    _, _, parent_probs_table = dependency_parser(**tokenized_input)
    # Mask every token that is neither ROOT (position 0) nor the first subword
    # of a word, so Chu-Liu/Edmonds ignores those rows and columns.
    irrelevant = torch.tensor(
        [
            idx.item()
            for idx in torch.arange(parent_probs_table.size(1))
            if idx.item() not in start_ids and idx.item() != 0
        ]
    )
    if irrelevant.nelement() > 0:
        parent_probs_table.index_fill_(1, irrelevant, torch.nan)
        parent_probs_table.index_fill_(2, irrelevant, torch.nan)
    # Decode once, demote any extra roots, then decode again.
    edmonds_head = apply_chu_liu_edmonds(parent_probs_table, tokenized_input, start_ids)
    parent_probs_table = resolve(edmonds_head, word_ids, parent_probs_table)
    edmonds_head = apply_chu_liu_edmonds(parent_probs_table, tokenized_input, start_ids)
    if labels:
        # Predict a relation label for each word from the label head's logits.
        predictions_labels = np.argmax(
            label_parser(**tokenized_input).logits.detach().cpu().numpy(), axis=-1
        )
        predicted_relations = get_relevant_tokens(predictions_labels[0], start_ids)
        predicted_relations = [
            INDEX2TAG[predicted_relations[idx]] for idx in range(len(sentence))
        ]
    else:
        predicted_relations = [""] * len(sentence)
    # Build one arc per word; indices are 1-based, with 0 reserved for ROOT.
    for idx, head in enumerate(edmonds_head):
        arc = {
            "start": min(idx + 1, word_endings[head][1]),
            "end": max(idx + 1, word_endings[head][1]),
            "label": predicted_relations[idx],
            "dir": "left" if idx + 1 < word_endings[head][1] else "right",
        }
        dep_dict["arcs"].append(arc)
    return dep_dict
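

if __name__ == "__main__":
    # Minimal smoke test for the model-free helpers above (a sketch only; the
    # parser and label models are not loaded here).
    print(preprocess_text("Hello, world!"))  # ['Hello', ',', 'world', '!']

    hidden = torch.arange(24.0).reshape(2, 4, 3)  # (batch, seq, hidden)
    keep = torch.tensor([[0, 2], [1, 3]])  # (batch, k) positions to gather
    print(batched_index_select(hidden, 1, keep).shape)  # torch.Size([2, 2, 3])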