Spaces:
Sleeping
Sleeping
import re | |
from typing import List, Dict, Set | |
import numpy as np | |
import torch | |
from ufal.chu_liu_edmonds import chu_liu_edmonds | |
# Dependency relation labels (Universal Dependencies relation names). The
# position of a label in this list is the class index the label parser's
# argmax is decoded with (see INDEX2TAG).
DEPENDENCY_RELATIONS = [
    "acl", "advcl", "advmod", "amod", "appos",
    "aux", "case", "cc", "ccomp", "conj",
    "cop", "csubj", "det", "iobj", "mark",
    "nmod", "nsubj", "nummod", "obj", "obl",
    "parataxis", "punct", "root", "vocative", "xcomp",
]
# Class index -> relation label.
INDEX2TAG = dict(enumerate(DEPENDENCY_RELATIONS))
# Relation label -> class index (inverse of INDEX2TAG).
TAG2INDEX = {tag: idx for idx, tag in INDEX2TAG.items()}
# Punctuation marks that are split off into tokens of their own. "\u00b7"
# (MIDDLE DOT) and "\u0387" (GREEK ANO TELEIA) both serve as the Greek
# semicolon; U+0387 normalizes to U+00B7 under NFC, so both are accepted.
_PUNCT_CLASS = "[.,!?()\u00b7\u0387;:]"
# Zero-width positions right before a mark not already preceded by a space,
# or right after a mark not already followed by one.
_PUNCT_SPACING_RE = re.compile(f"(?<! )(?={_PUNCT_CLASS})|(?<={_PUNCT_CLASS})(?! )")


def preprocess_text(text: str) -> List[str]:
    """Split *text* into whitespace-delimited words with punctuation isolated.

    A space is inserted before and after each mark in ``_PUNCT_CLASS`` unless
    one is already there, then the string is split on whitespace.

    Args:
        text: Raw sentence text.

    Returns:
        The list of tokens; empty for blank input.
    """
    return _PUNCT_SPACING_RE.sub(" ", text.strip()).split()
def batched_index_select(
    input: torch.Tensor, dim: int, index: torch.Tensor
) -> torch.Tensor:
    """Gather along *dim* with a separate index row for every batch element.

    ``index`` is reshaped so that its non-batch entries lie along *dim*, then
    broadcast over every other non-batch axis of *input* before
    ``torch.gather``.

    Args:
        input: Tensor whose first axis is the batch axis.
        dim: Axis of *input* to select along.
        index: Indices to select, one row per batch element.

    Returns:
        ``out[b, ..., k, ...] == input[b, ..., index[b, k], ...]`` with ``k``
        on axis *dim*.
    """
    view_shape = [input.size(0)]
    view_shape.extend(-1 if axis == dim else 1 for axis in range(1, input.dim()))
    expand_shape = list(input.shape)
    expand_shape[0] = expand_shape[dim] = -1
    broadcast_index = index.view(view_shape).expand(expand_shape)
    return torch.gather(input, dim, broadcast_index)
def get_relevant_tokens(tokenized: torch.Tensor, start_ids: Set[int]) -> List[int]:
    """Return the scalar values of *tokenized* at the positions in *start_ids*.

    Positions are visited in ascending order, so the result follows the
    sequence order of *tokenized* rather than the unordered set. Positions in
    *start_ids* outside ``range(len(tokenized))`` are ignored.
    """
    selected = []
    for position, value in enumerate(tokenized):
        if position in start_ids:
            selected.append(value.item())
    return selected
def resolve(
    edmonds_head: List[int], word_ids: List[int], parent_probs_table: torch.Tensor
) -> torch.Tensor:
    """Keep a single ROOT attachment when the decoded tree has several.

    When more than one word is attached to position 0, the root with the most
    dependents is kept. For every other root word, the table entry
    ``[0][t][0]`` is zeroed, where ``t`` is that word's first token.
    NOTE(review): presumably that entry is the score of attaching ``t`` to
    ROOT; confirm the axis layout of the table.

    Args:
        edmonds_head: Head token position of each word, in word order
            (0 means attached to ROOT).
        word_ids: Per-token word ids from the tokenizer (``None`` for special
            tokens); ``word_ids.index(w)`` is the first token of word ``w``.
        parent_probs_table: Arc score table; modified in place.

    Returns:
        The same ``parent_probs_table`` object.
    """
    roots = [word for word, head in enumerate(edmonds_head) if head == 0]
    if len(roots) > 1:
        first_token = {word: word_ids.index(word) for word in roots}
        # Heads are token positions, so a root word's dependents are the
        # entries equal to its first token, not to its word index.
        main_root = max(roots, key=lambda word: edmonds_head.count(first_token[word]))
        for word in roots:
            if word != main_root:
                parent_probs_table[0][first_token[word]][0] = 0
    return parent_probs_table
def apply_chu_liu_edmonds(
    parent_probs_table: torch.Tensor, tokenized_input: Dict, start_ids: Set[int]
) -> List[int]:
    """Decode a spanning tree from the arc scores and return per-word heads.

    Side effect: the full decoded head tensor (shape ``(1, n)``) is stored in
    ``tokenized_input["head_labels"]``.

    Args:
        parent_probs_table: Batched arc score table (batch size 1).
        tokenized_input: Collated model inputs; receives ``"head_labels"``.
        start_ids: Token positions that start a word.

    Returns:
        The head position for each word-start token, in sequence order.
    """
    scores = parent_probs_table
    if scores.shape[1] != scores.shape[2]:
        # Non-square table: drop its first column.
        # NOTE(review): presumably an extra leading column; confirm layout.
        scores = scores[:, :, 1:]
    score_matrix = scores.squeeze(0).cpu().numpy().astype("double")
    heads, _tree_score = chu_liu_edmonds(score_matrix)
    head_tensor = torch.tensor(heads).unsqueeze(0)
    # Heads equal to -1 (presumably the decoder's marker for the root node)
    # are re-mapped to position 0.
    head_tensor.masked_fill_(head_tensor == -1, 0)
    tokenized_input["head_labels"] = head_tensor
    return get_relevant_tokens(head_tensor[0], start_ids)
def get_word_endings(tokenized_input):
    """Map token positions to word spans for a single tokenized sentence.

    Args:
        tokenized_input: Tokenizer output supporting ``word_ids(batch_index=)``
            and ``word_to_tokens(batch_or_word_index=, word_index=)``
            (presumably a Hugging Face fast-tokenizer encoding).

    Returns:
        A tuple ``(word_endings, start_ids, word_ids)``:

        - ``word_endings``: token position -> ``(end, word_number)``, where
          ``end`` is the span end reported by ``word_to_tokens`` and
          ``word_number`` is 1-based. Position 0 maps to ``(1, 0)``, i.e.
          word number 0 (ROOT).
        - ``start_ids``: set of each word's first token position.
        - ``word_ids``: the per-token word ids, unchanged.
    """
    word_ids = tokenized_input.word_ids(batch_index=0)
    endings = {0: (1, 0)}
    first_tokens = set()
    # Visited once per token: multi-token words are re-processed with the
    # same result, and special tokens (word id None) are skipped.
    for word_id in (w for w in word_ids if w is not None):
        first, stop = tokenized_input.word_to_tokens(
            batch_or_word_index=0, word_index=word_id
        )
        first_tokens.add(first)
        entry = (stop, word_id + 1)
        endings[first] = entry
        # The range includes position `stop` itself; a following word that
        # starts there overwrites that entry when it is processed.
        endings.update(dict.fromkeys(range(first + 1, stop + 1), entry))
    return endings, first_tokens, word_ids
def get_dependencies(
    dependency_parser,
    label_parser,
    tokenizer,
    collator,
    labels: bool,
    sentence: List[str],
) -> Dict:
    """Parse a pre-split sentence into a displaCy-style dependency dict.

    Args:
        dependency_parser: Model whose call returns a 3-tuple; the third
            element is used as the token-level arc score table.
        label_parser: Model whose output ``.logits`` are argmax-decoded via
            ``INDEX2TAG``; only called when *labels* is true.
        tokenizer: Tokenizer called with ``is_split_into_words=True``.
        collator: Callable turning a list of encodings into model inputs.
        labels: Whether to predict relation labels for the arcs.
        sentence: The sentence as a list of words.

    Returns:
        ``{"words": [...], "arcs": [...]}``. Each arc has ``start``, ``end``,
        ``label`` and ``dir`` keys, with word numbers counting ROOT as 0. Only
        words that survive tokenizer truncation get an arc.
        NOTE(review): ``words`` only ever holds the ROOT entry here;
        presumably the caller adds the sentence words. Confirm.
    """
    tokenized_input = tokenizer(
        sentence, truncation=True, is_split_into_words=True, add_special_tokens=True
    )
    dep_dict: Dict[str, List[Dict[str, object]]] = {
        "words": [{"text": "ROOT", "tag": ""}],
        "arcs": [],
    }
    word_endings, start_ids, word_ids = get_word_endings(tokenized_input)
    tokenized_input = collator([tokenized_input])

    # Inference only: no autograd graph is needed, and tensors produced here
    # can be converted with .numpy() (in apply_chu_liu_edmonds) safely.
    with torch.no_grad():
        _, _, parent_probs_table = dependency_parser(**tokenized_input)

        # NaN-mask every token that neither starts a word nor is position 0,
        # so the tree is decoded over word-initial tokens only.
        irrelevant = [
            idx
            for idx in range(parent_probs_table.size(1))
            if idx != 0 and idx not in start_ids
        ]
        if irrelevant:
            mask_index = torch.tensor(
                irrelevant, dtype=torch.long, device=parent_probs_table.device
            )
            parent_probs_table.index_fill_(1, mask_index, torch.nan)
            parent_probs_table.index_fill_(2, mask_index, torch.nan)

        # Decode, keep a single ROOT attachment, then decode again.
        edmonds_head = apply_chu_liu_edmonds(
            parent_probs_table, tokenized_input, start_ids
        )
        parent_probs_table = resolve(edmonds_head, word_ids, parent_probs_table)
        edmonds_head = apply_chu_liu_edmonds(
            parent_probs_table, tokenized_input, start_ids
        )

        if labels:
            logits = label_parser(**tokenized_input).logits
            predictions_labels = np.argmax(logits.detach().cpu().numpy(), axis=-1)
            relation_ids = get_relevant_tokens(predictions_labels[0], start_ids)
            # One label per word that survived truncation. Indexing by
            # len(sentence) raised IndexError when words were truncated away.
            predicted_relations = [INDEX2TAG[tag_id] for tag_id in relation_ids]
        else:
            predicted_relations = [""] * len(edmonds_head)

    for idx, head in enumerate(edmonds_head):
        word_number = idx + 1
        head_number = word_endings[head][1]
        dep_dict["arcs"].append(
            {
                "start": min(word_number, head_number),
                "end": max(word_number, head_number),
                "label": predicted_relations[idx],
                "dir": "left" if word_number < head_number else "right",
            }
        )
    return dep_dict