# athenas-lens / utils.py
# Author: bowphs
# Commit 3bc4816: "Add initial attempt of a code framework."
import re
from typing import List, Dict, Set
import numpy as np
import torch
from ufal.chu_liu_edmonds import chu_liu_edmonds
# Dependency-relation labels. A label's position in this list is the class id
# that get_dependencies decodes through INDEX2TAG after taking the argmax of
# the label parser's logits, so the order must not change.
# The names follow Universal Dependencies relation labels (a subset).
DEPENDENCY_RELATIONS = """
    acl advcl advmod amod appos aux case cc ccomp conj cop csubj det
    iobj mark nmod nsubj nummod obj obl parataxis punct root vocative xcomp
""".split()

# Class id -> label name, and the inverse mapping.
INDEX2TAG = dict(enumerate(DEPENDENCY_RELATIONS))
TAG2INDEX = {tag: idx for idx, tag in INDEX2TAG.items()}
# Inserts a space before and after each punctuation character unless one is
# already there. The class includes U+00B7 (middle dot) and U+0387 (GREEK ANO
# TELEIA, canonically equivalent to U+00B7), the Greek high stop. An earlier
# revision had the mis-decoded character U+8DEF here instead, which never
# matched.
_PUNCTUATION_SPACING = re.compile(
    r"(?<! )(?=[.,!?()\u00b7\u0387;:])|(?<=[.,!?()\u00b7\u0387;:])(?! )"
)


def preprocess_text(text: str) -> List[str]:
    """Split raw text into word tokens, detaching punctuation.

    Leading/trailing whitespace is stripped, a space is inserted on both sides
    of each of ``. , ! ? ( ) · ; :`` (where not already present), and the
    result is split on whitespace.

    Args:
        text: Raw input text.

    Returns:
        The list of tokens; empty for blank input.
    """
    return _PUNCTUATION_SPACING.sub(" ", text.strip()).split()
def batched_index_select(
    input: torch.Tensor, dim: int, index: torch.Tensor
) -> torch.Tensor:
    """Select positions along ``dim`` with a separate index row per batch item.

    ``index`` must hold ``input.size(0)`` rows of positions (e.g. shape
    ``(B, k)``). For ``dim >= 1`` the result satisfies
    ``out[b] == input[b].index_select(dim - 1, index[b])``.

    Args:
        input: Tensor whose first axis is the batch axis.
        dim: Axis of ``input`` to select along.
        index: Integer positions, one row per batch element.

    Returns:
        ``input`` gathered along ``dim``; that axis gets the row length of ``index``.
    """
    batch_size = input.size(0)
    # Size 1 on every non-batch axis except `dim`, which receives the index row.
    view_shape = [batch_size]
    view_shape.extend(-1 if axis == dim else 1 for axis in range(1, input.dim()))
    # Broadcast over the other axes; -1 keeps the batch and `dim` extents as-is.
    expand_shape = list(input.size())
    expand_shape[0] = -1
    expand_shape[dim] = -1
    return input.gather(dim, index.view(view_shape).expand(expand_shape))
def get_relevant_tokens(tokenized: torch.Tensor, start_ids: Set[int]) -> List[int]:
    """Return the scalar values of ``tokenized`` at the positions in ``start_ids``.

    Positions are visited in ascending order. Ids outside
    ``range(len(tokenized))`` are ignored. Each element is converted with
    ``.item()``, so 1-D torch tensors and NumPy arrays both work.
    """
    return [
        value.item()
        for position, value in enumerate(tokenized)
        if position in start_ids
    ]
def resolve(
    edmonds_head: List[int], word_ids: List[int], parent_probs_table: torch.Tensor
) -> torch.Tensor:
    """Suppress ROOT attachments for all but one root word, in place.

    ``edmonds_head`` holds, for each word in order, the token index of its
    head (0 meaning ROOT), as produced by ``apply_chu_liu_edmonds``. When more
    than one word is attached to ROOT, the root word with the most dependents
    is kept. For every other root word, the ROOT score of that word's first
    token is set to 0.

    Dependents of a word point at that word's *first token*, so they are
    counted via ``word_ids.index(word)``. An earlier version counted the bare
    word index, mixing word and token positions.

    Args:
        edmonds_head: Head token index per word.
        word_ids: Per-token word ids (``None`` for special tokens).
        parent_probs_table: Score table, modified in place. Presumably laid
            out as ``[batch][dependent_token][head_token]``; confirm.

    Returns:
        The same ``parent_probs_table`` object.
    """
    root_words = [word for word, head in enumerate(edmonds_head) if head == 0]
    if len(root_words) <= 1:
        return parent_probs_table

    def n_dependents(word: int) -> int:
        # Heads are token indices, so look up the word's first token.
        return edmonds_head.count(word_ids.index(word))

    # Ties resolve to the earliest root word.
    main_root = max(root_words, key=n_dependents)
    for word in root_words:
        if word != main_root:
            # NOTE(review): 0 is only "low" if scores are probabilities; with
            # log-scores it would be a high score. Also, when the table is
            # non-square, apply_chu_liu_edmonds drops column 0. Confirm both.
            parent_probs_table[0][word_ids.index(word)][0] = 0
    return parent_probs_table
def apply_chu_liu_edmonds(
    parent_probs_table: torch.Tensor, tokenized_input: Dict, start_ids: Set[int]
) -> List[int]:
    """Decode a dependency tree and return the head token of each word.

    The (batch-size-1) score table is passed to ``chu_liu_edmonds``. The
    decoded head array is stored in ``tokenized_input["head_labels"]`` (a
    side effect the label parser later consumes). The heads at the
    word-initial token positions are returned.

    Args:
        parent_probs_table: Score table of shape ``(1, n, n)`` or
            ``(1, n, n + 1)``. In the latter case the first column is dropped.
            Presumably ``[batch][dependent][head]``; confirm.
        tokenized_input: Collated model inputs; mutated.
        start_ids: Token indices of each word's first sub-token.

    Returns:
        Head token index per word, in word order (0 = ROOT).
    """
    if parent_probs_table.shape[1] != parent_probs_table.shape[2]:
        parent_probs_table = parent_probs_table[:, :, 1:]
    # detach(): .numpy() raises on tensors that require grad, and callers do
    # not necessarily run under torch.no_grad().
    scores = parent_probs_table.squeeze(0).detach().cpu().numpy().astype("double")
    heads, _ = chu_liu_edmonds(scores)
    heads = torch.tensor(heads).unsqueeze(0)
    # The decoder reports -1 for the head of the root position; map it to 0.
    heads[heads == -1] = 0
    tokenized_input["head_labels"] = heads
    return get_relevant_tokens(heads[0], start_ids)
def get_word_endings(tokenized_input):
    """Index the word structure of a single tokenized sentence.

    Args:
        tokenized_input: Tokenizer output exposing ``word_ids(batch_index=...)``
            and ``word_to_tokens(batch_or_word_index=..., word_index=...)``.

    Returns:
        A tuple ``(word_endings, start_ids, word_ids)``:
        - ``word_endings``: token index -> ``(end, word_number)``. ``end`` is
          as returned by ``word_to_tokens`` and ``word_number`` is the 1-based
          word position. Token 0 maps to ``(1, 0)``, i.e. ROOT.
        - ``start_ids``: set of each word's first token index.
        - ``word_ids``: per-token word ids (``None`` for special tokens).
    """
    word_ids = tokenized_input.word_ids(batch_index=0)
    word_endings = {0: (1, 0)}
    start_ids = set()
    last_word = None
    for word_id in word_ids:
        # Special tokens belong to no word. A repeat of the word just handled
        # (its next sub-token) would only rewrite identical entries.
        if word_id is None or word_id == last_word:
            continue
        last_word = word_id
        first, stop = tokenized_input.word_to_tokens(
            batch_or_word_index=0, word_index=word_id
        )
        start_ids.add(first)
        entry = (stop, word_id + 1)
        # NOTE(review): the range includes `stop` itself. If `stop` is
        # exclusive, that is the following token, which the next word then
        # overwrites. Confirm this is intended.
        for position in range(first, stop + 1):
            word_endings[position] = entry
    return word_endings, start_ids, word_ids
def _mask_non_initial_tokens(
    parent_probs_table: torch.Tensor, start_ids: Set[int]
) -> None:
    """NaN-fill, in place, score rows/columns of tokens that are not word-initial.

    Token 0 and every index in ``start_ids`` are kept. All other positions
    along dims 1 and 2 become NaN, presumably treated by ``chu_liu_edmonds``
    as missing arcs (confirm).
    """
    seq_len = parent_probs_table.size(1)
    irrelevant = torch.tensor(
        [idx for idx in range(1, seq_len) if idx not in start_ids],
        dtype=torch.long,
        # Build the index on the table's device to avoid a cross-device index.
        device=parent_probs_table.device,
    )
    if irrelevant.numel() > 0:
        parent_probs_table.index_fill_(1, irrelevant, torch.nan)
        # NOTE(review): if dim 2 is one wider than dim 1, apply_chu_liu_edmonds
        # drops column 0, shifting these columns by one relative to the rows.
        parent_probs_table.index_fill_(2, irrelevant, torch.nan)


def get_dependencies(
    dependency_parser,
    label_parser,
    tokenizer,
    collator,
    labels: bool,
    sentence: List[str],
) -> Dict:
    """Parse a pre-split sentence into a words/arcs dependency dictionary.

    Args:
        dependency_parser: Called on the collated inputs. Its third output is
            used as the head-score table.
        label_parser: Called on the collated inputs (including the decoded
            ``head_labels``). Its ``.logits`` are argmax-ed into relation ids.
        tokenizer: Tokenizer called with ``is_split_into_words=True``.
        collator: Called on a one-element list to build model inputs.
        labels: Whether to predict relation labels (otherwise labels are "").
        sentence: The sentence as a list of words.

    Returns:
        ``{"words": [...], "arcs": [...]}``. Each arc has ``start``/``end``
        (1-based word positions, 0 = ROOT), ``label`` and ``dir``. The format
        looks like displaCy's manual input; confirm.
    """
    tokenized_input = tokenizer(
        sentence, truncation=True, is_split_into_words=True, add_special_tokens=True
    )
    dep_dict: Dict[str, List[Dict[str, object]]] = {
        # NOTE(review): only ROOT is added here while arcs use 1-based word
        # positions; presumably the caller appends the words. Confirm.
        "words": [{"text": "ROOT", "tag": ""}],
        "arcs": [],
    }
    word_endings, start_ids, word_ids = get_word_endings(tokenized_input)
    tokenized_input = collator([tokenized_input])

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        _, _, parent_probs_table = dependency_parser(**tokenized_input)
        _mask_non_initial_tokens(parent_probs_table, start_ids)
        # Decode, suppress extra ROOT attachments, then decode again.
        edmonds_head = apply_chu_liu_edmonds(
            parent_probs_table, tokenized_input, start_ids
        )
        parent_probs_table = resolve(edmonds_head, word_ids, parent_probs_table)
        edmonds_head = apply_chu_liu_edmonds(
            parent_probs_table, tokenized_input, start_ids
        )

        if labels:
            logits = label_parser(**tokenized_input).logits
            predictions_labels = np.argmax(logits.detach().cpu().numpy(), axis=-1)
            # One label per word that survived truncation. Indexing by
            # len(sentence) raised IndexError when the tokenizer truncated.
            predicted_relations = [
                INDEX2TAG[label]
                for label in get_relevant_tokens(predictions_labels[0], start_ids)
            ]
        else:
            predicted_relations = [""] * len(edmonds_head)

    for idx, head in enumerate(edmonds_head):
        dependent = idx + 1  # 1-based word position; 0 is ROOT
        governor = word_endings[head][1]
        dep_dict["arcs"].append(
            {
                "start": min(dependent, governor),
                "end": max(dependent, governor),
                "label": predicted_relations[idx],
                "dir": "left" if dependent < governor else "right",
            }
        )
    return dep_dict