import random
import re
from collections import defaultdict
from typing import Iterable, Iterator, List, MutableSet, Optional, Tuple, TypeVar, Union

import torch
import torch.nn.functional as F
from rex.data.collate_fn import GeneralCollateFn
from rex.data.transforms.base import CachedTransformBase, CachedTransformOneBase
from rex.metrics import calc_p_r_f1_from_tp_fp_fn
from rex.utils.io import load_json
from rex.utils.iteration import windowed_queue_iter
from rex.utils.logging import logger
from transformers import AutoTokenizer
from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast
from transformers.models.deberta_v2.tokenization_deberta_v2_fast import (
    DebertaV2TokenizerFast,
)
from transformers.tokenization_utils_base import BatchEncoding

from src.utils import (
    decode_nnw_nsw_thw_mat,
    decode_nnw_thw_mat,
    encode_nnw_nsw_thw_mat,
    encode_nnw_thw_mat,
)

Filled = TypeVar("Filled")

class PaddingMixin:
    max_seq_len: int

    def pad_seq(
        self, batch_seqs: List[List[Filled]], fill: Filled
    ) -> List[List[Filled]]:
        """Right-pad every sequence to the batch max length (mutates and returns `batch_seqs`)."""
        max_len = max(len(seq) for seq in batch_seqs)
        assert max_len <= self.max_seq_len
        for i in range(len(batch_seqs)):
            batch_seqs[i] = batch_seqs[i] + [fill] * (max_len - len(batch_seqs[i]))
        return batch_seqs

    def pad_mat(
        self, mats: List[torch.Tensor], fill: Union[int, float]
    ) -> List[torch.Tensor]:
        """Pad the two sequence dims of each label matrix to the batch max length."""
        max_len = max(mat.shape[0] for mat in mats)
        assert max_len <= self.max_seq_len
        for i in range(len(mats)):
            num_add = max_len - mats[i].shape[0]
            # (0, 0, 0, num_add, 0, num_add) pads dims -2 and -3, i.e. the two
            # sequence dims of a [seq_len, seq_len, channels] tensor.
            mats[i] = F.pad(
                mats[i], (0, 0, 0, num_add, 0, num_add), mode="constant", value=fill
            )
        return mats

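# Usage sketch for PaddingMixin.pad_mat (not part of the pipeline; shapes are
# made-up and assume [seq_len, seq_len, channels] matrices, which is what the
# padding tuple above targets, plus a mixin instance with max_seq_len >= 5):
#
#   mats = [torch.zeros(3, 3, 2), torch.zeros(5, 5, 2)]
#   padded = mixin.pad_mat(mats, fill=0)
#   assert all(m.shape[:2] == (5, 5) for m in padded)
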
class PointerTransformMixin:
    tokenizer: BertTokenizerFast
    max_seq_len: int
    space_token: str = "[unused1]"

    def build_ins(
        self,
        query_tokens: list[str],
        context_tokens: list[str],
        answer_indexes: list[list[int]],
        add_context_tokens: Optional[list[str]] = None,
    ) -> Tuple:
        """Build a single query + context instance.

        Layout: [CLS] query [SEP] context [SEP] and, if `add_context_tokens` is
        given, add_context [SEP]. The returned `mask` uses segment codes
        1/2/3/4/5(/6/7) for these parts and 0 for padding. Answer indexes are
        shifted by `offset` (the position of the first context token), and spans
        falling outside the truncated context are dropped.
        """
        # reserve 3 slots for [CLS] and the two [SEP] tokens
        reserved_seq_len = self.max_seq_len - 3 - len(query_tokens)

        if reserved_seq_len < 20:
            raise ValueError(
                f"Query {query_tokens} too long: {len(query_tokens)} "
                f"while max seq len is {self.max_seq_len}"
            )

        input_tokens = [self.tokenizer.cls_token]
        input_tokens += query_tokens
        input_tokens += [self.tokenizer.sep_token]
        offset = len(input_tokens)
        input_tokens += context_tokens[:reserved_seq_len]
        available_token_range = range(
            offset, offset + len(context_tokens[:reserved_seq_len])
        )
        input_tokens += [self.tokenizer.sep_token]

        add_context_len = 0
        max_add_context_len = self.max_seq_len - len(input_tokens) - 1
        add_context_flag = False
        if add_context_tokens and len(add_context_tokens) > 0:
            add_context_flag = True
            add_context_len = len(add_context_tokens[:max_add_context_len])
            input_tokens += add_context_tokens[:max_add_context_len]
            input_tokens += [self.tokenizer.sep_token]
        # map pure-whitespace tokens to a reserved placeholder so they keep a valid id
        new_tokens = []
        for t in input_tokens:
            if len(t.strip()) > 0:
                new_tokens.append(t)
            else:
                new_tokens.append(self.space_token)
        input_tokens = new_tokens
        input_ids = self.tokenizer.convert_tokens_to_ids(input_tokens)

        mask = [1]
        mask += [2] * len(query_tokens)
        mask += [3]
        mask += [4] * len(context_tokens[:reserved_seq_len])
        mask += [5]
        if add_context_flag:
            mask += [6] * add_context_len
            mask += [7]
        assert len(mask) == len(input_ids) <= self.max_seq_len

        available_spans = [tuple(i + offset for i in index) for index in answer_indexes]
        available_spans = list(
            filter(
                lambda index: all(i in available_token_range for i in index),
                available_spans,
            )
        )

        token_len = len(input_ids)
        pad_len = self.max_seq_len - token_len
        input_tokens += pad_len * [self.tokenizer.pad_token]
        input_ids += pad_len * [self.tokenizer.pad_token_id]
        mask += pad_len * [0]

        return input_tokens, input_ids, mask, offset, available_spans

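    # Worked example for `build_ins` (token strings are illustrative only):
    #
    #   query_tokens   = ["who", "?"]
    #   context_tokens = ["Bob", "likes", "tea"]
    #   answer_indexes = [[0]]                   # "Bob" at context position 0
    #
    #   input_tokens    = [CLS] who ?  [SEP] Bob likes tea [SEP] [PAD] ...
    #   mask            =  1    2   2   3    4    4    4    5    0 ...
    #   offset          = 4                      # position of the first context token
    #   available_spans = [(4,)]                 # answer index shifted by `offset`
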
    def update_labels(self, data: dict) -> dict:
        """Build w2-style label matrices from the collected spans.

        Channel 0 links every span position to the next one in the span, and
        channel 1 links the last position back to the first (tail -> head).
        """
        bs = len(data["input_ids"])
        seq_len = self.max_seq_len
        labels = torch.zeros((bs, 2, seq_len, seq_len))
        for i, batch_spans in enumerate(data["available_spans"]):
            for span in batch_spans:
                if len(span) == 1:
                    labels[i, :, span[0], span[0]] = 1
                else:
                    for s, e in windowed_queue_iter(span, 2, 1, drop_last=True):
                        labels[i, 0, s, e] = 1
                    labels[i, 1, span[-1], span[0]] = 1
        data["labels"] = labels
        return data

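    # Example: a (possibly discontiguous) span at token positions (5, 6, 8) in
    # batch item i produces
    #
    #   labels[i, 0, 5, 6] = 1   # 5 -> 6
    #   labels[i, 0, 6, 8] = 1   # 6 -> 8
    #   labels[i, 1, 8, 5] = 1   # tail -> head
    #
    # and a single-token span (7,) sets labels[i, :, 7, 7] = 1 on both channels.
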
    def update_consecutive_span_labels(self, data: dict) -> dict:
        bs = len(data["input_ids"])
        seq_len = self.max_seq_len
        labels = torch.zeros((bs, 1, seq_len, seq_len))
        for i, batch_spans in enumerate(data["available_spans"]):
            for span in batch_spans:
                assert span == tuple(sorted(set(span)))
                if len(span) == 1:
                    labels[i, 0, span[0], span[0]] = 1
                else:
                    labels[i, 0, span[0], span[-1]] = 1
        data["labels"] = labels
        return data

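    # Example: in "cons" mode a contiguous span at positions (5, 6, 7) only marks
    # its boundary cell, labels[i, 0, 5, 7] = 1; a single-token span (7,) marks
    # labels[i, 0, 7, 7] = 1.
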
class CachedPointerTaggingTransform(CachedTransformBase, PointerTransformMixin):
    def __init__(
        self,
        max_seq_len: int,
        plm_dir: str,
        ent_type2query_filepath: str,
        mode: str = "w2",
        negative_sample_prob: float = 1.0,
    ) -> None:
        super().__init__()

        self.max_seq_len: int = max_seq_len
        self.tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained(plm_dir)
        self.ent_type2query: dict = load_json(ent_type2query_filepath)
        self.negative_sample_prob = negative_sample_prob

        self.collate_fn: GeneralCollateFn = GeneralCollateFn(
            {
                "input_ids": torch.long,
                "mask": torch.long,
                "labels": torch.long,
            },
            guessing=False,
            missing_key_as_null=True,
        )
        if mode == "w2":
            self.collate_fn.update_before_tensorify = self.update_labels
        elif mode == "cons":
            self.collate_fn.update_before_tensorify = (
                self.update_consecutive_span_labels
            )
        else:
            raise ValueError(f"Unrecognized mode: {mode}")

    def transform(
        self,
        transform_loader: Iterator,
        dataset_name: Optional[str] = None,
        **kwargs,
    ) -> Iterable:
        final_data = []

        for data in transform_loader:
            ent_type2ents = defaultdict(set)
            for ent in data["ents"]:
                ent_type2ents[ent["type"]].add(tuple(ent["index"]))
            for ent_type in self.ent_type2query:
                gold_ents = ent_type2ents[ent_type]
                if (
                    len(gold_ents) < 1
                    and dataset_name == "train"
                    and random.random() > self.negative_sample_prob
                ):
                    # negative sampling: randomly drop queries without gold
                    # entities during training
                    continue

                query = self.ent_type2query[ent_type]
                query_tokens = self.tokenizer.tokenize(query)
                try:
                    res = self.build_ins(query_tokens, data["tokens"], gold_ents)
                except (ValueError, AssertionError):
                    continue
                input_tokens, input_ids, mask, offset, available_spans = res
                ins = {
                    "id": data.get("id", str(len(final_data))),
                    "ent_type": ent_type,
                    "gold_ents": gold_ents,
                    "raw_tokens": data["tokens"],
                    "input_tokens": input_tokens,
                    "input_ids": input_ids,
                    "mask": mask,
                    "offset": offset,
                    "available_spans": available_spans,
                    # labels are generated in the collate fn before tensorification
                    "labels": None,
                }
                final_data.append(ins)

        return final_data

    def predict_transform(self, texts: List[str]):
        dataset = []
        for text_id, text in enumerate(texts):
            data_id = f"Prediction#{text_id}"
            tokens = self.tokenizer.tokenize(text)
            dataset.append(
                {
                    "id": data_id,
                    "tokens": tokens,
                    "ents": [],
                }
            )
        final_data = self(dataset, disable_pbar=True)
        return final_data

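# Usage sketch for CachedPointerTaggingTransform (the checkpoint name and file
# path are placeholders, not values shipped with this repo):
#
#   transform = CachedPointerTaggingTransform(
#       max_seq_len=512,
#       plm_dir="bert-base-cased",
#       ent_type2query_filepath="data/ent_type2query.json",  # {"PER": "Find all person names.", ...}
#   )
#   instances = transform.predict_transform(["Bob likes tea."])
#   # `transform.collate_fn` is then handed to the DataLoader as its collate_fn.
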
class CachedPointerMRCTransform(CachedTransformBase, PointerTransformMixin):
    def __init__(
        self,
        max_seq_len: int,
        plm_dir: str,
        mode: str = "w2",
    ) -> None:
        super().__init__()

        self.max_seq_len: int = max_seq_len
        self.tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained(plm_dir)

        self.collate_fn: GeneralCollateFn = GeneralCollateFn(
            {
                "input_ids": torch.long,
                "mask": torch.long,
                "labels": torch.long,
            },
            guessing=False,
            missing_key_as_null=True,
        )

        if mode == "w2":
            self.collate_fn.update_before_tensorify = self.update_labels
        elif mode == "cons":
            self.collate_fn.update_before_tensorify = (
                self.update_consecutive_span_labels
            )
        else:
            raise ValueError(f"Unrecognized mode: {mode}")

    def transform(
        self,
        transform_loader: Iterator,
        dataset_name: Optional[str] = None,
        **kwargs,
    ) -> Iterable:
        final_data = []
        for data in transform_loader:
            try:
                res = self.build_ins(
                    data["query_tokens"],
                    data["context_tokens"],
                    data["answer_index"],
                    data.get("background_tokens"),
                )
            except (ValueError, AssertionError):
                continue
            input_tokens, input_ids, mask, offset, available_spans = res
            ins = {
                "id": data.get("id", str(len(final_data))),
                "gold_spans": sorted(set(tuple(x) for x in data["answer_index"])),
                "raw_tokens": data["context_tokens"],
                "input_tokens": input_tokens,
                "input_ids": input_ids,
                "mask": mask,
                "offset": offset,
                "available_spans": available_spans,
                "labels": None,
            }
            final_data.append(ins)

        return final_data

    def predict_transform(self, data: list[dict]):
        """
        Args:
            data: a list of dicts with `query`, `context`, and optional `background` strings
        """
        dataset = []
        for idx, ins in enumerate(data):
            data_id = f"Prediction#{idx}"
            dataset.append(
                {
                    "id": data_id,
                    "query_tokens": list(ins["query"]),
                    "context_tokens": list(ins["context"]),
                    "background_tokens": list(ins.get("background") or ""),
                    "answer_index": [],
                }
            )
        final_data = self(dataset, disable_pbar=True, num_samples=0)
        return final_data

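# Input sketch for CachedPointerMRCTransform.predict_transform (strings are
# made-up; `background` may be omitted):
#
#   preds = transform.predict_transform(
#       [{"query": "Who likes tea?", "context": "Bob likes tea.", "background": ""}]
#   )
#
# Note that list(...) above splits each string into single characters, which
# suits a character-level vocabulary (e.g., Chinese BERT checkpoints);
# pre-tokenized inputs should be fed through `transform` instead.
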
class CachedLabelPointerTransform(CachedTransformOneBase):
    """Transform for label-token linking with skip-consecutive spans."""

    def __init__(
        self,
        max_seq_len: int,
        plm_dir: str,
        mode: str = "w2",
        label_span: str = "tag",
        include_instructions: bool = True,
        **kwargs,
    ) -> None:
        super().__init__()

        self.max_seq_len: int = max_seq_len
        self.mode = mode
        self.label_span = label_span
        self.include_instructions = include_instructions

        self.tokenizer: DebertaV2TokenizerFast = DebertaV2TokenizerFast.from_pretrained(
            plm_dir
        )
        # special marker tokens added to the vocabulary:
        self.lc_token = "[LC]"  # classification label
        self.lm_token = "[LM]"  # mention / entity-type label
        self.lr_token = "[LR]"  # relation / role label
        self.i_token = "[I]"  # instruction prefix
        self.tl_token = "[TL]"  # text prefix when label tokens are present
        self.tp_token = "[TP]"  # text prefix when no label tokens are present
        self.b_token = "[B]"  # background prefix
        num_added = self.tokenizer.add_tokens(
            [
                self.lc_token,
                self.lm_token,
                self.lr_token,
                self.i_token,
                self.tl_token,
                self.tp_token,
                self.b_token,
            ]
        )
        assert num_added == 7

        self.collate_fn: GeneralCollateFn = GeneralCollateFn(
            {
                "input_ids": torch.long,
                "mask": torch.long,
                "labels": torch.long,
                "spans": None,
            },
            guessing=False,
            missing_key_as_null=True,
            discard_missing=False,
        )

        self.collate_fn.update_before_tensorify = self.skip_consecutive_span_labels

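    # Note: `add_tokens` above grows the tokenizer vocabulary by 7, so the
    # downstream model is expected to resize its input embeddings to match
    # (for HF models, `model.resize_token_embeddings(len(tokenizer))`); this
    # transform only handles the tokenizer side.
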
    def transform(self, instance: dict, **kwargs):
        """Convert one raw instance into a token sequence with pointer spans.

        Sequence layout (with the segment codes written into `mask`):
        [CLS]=1, [I]=2, instruction=3, label tag=4, label content=5,
        [TL]/[TP]=6, text=7, [B]=8, background=9, final [SEP]=10,
        and 0 for padding added later in the collate fn.
        """
        tokens = [self.tokenizer.cls_token]
        mask = [1]
        # label string -> token span, per label kind ("lc"/"lm"/"lr")
        label_map = {"lc": {}, "lm": {}, "lr": {}}
        # token span -> {"type", "task", "string"} for decoding
        span_to_label = {}

        def _update_seq(
            label: str,
            label_type: str,
            task: str = "",
            label_mask: int = 4,
            content_mask: int = 5,
        ):
            # Register a label once: append its marker token ([LC]/[LM]/[LR])
            # and its tokenized string, then remember which token span
            # (tag or content, depending on `label_span`) represents it.
            if label not in label_map[label_type]:
                label_token_map = {
                    "lc": self.lc_token,
                    "lm": self.lm_token,
                    "lr": self.lr_token,
                }
                label_tag_start_idx = len(tokens)
                tokens.append(label_token_map[label_type])
                mask.append(label_mask)
                label_tag_end_idx = len(tokens) - 1
                label_tokens = self.tokenizer(label, add_special_tokens=False).tokens()
                label_content_start_idx = len(tokens)
                tokens.extend(label_tokens)
                mask.extend([content_mask] * len(label_tokens))
                label_content_end_idx = len(tokens) - 1

                if self.label_span == "tag":
                    start_idx = label_tag_start_idx
                    end_idx = label_tag_end_idx
                elif self.label_span == "content":
                    start_idx = label_content_start_idx
                    end_idx = label_content_end_idx
                else:
                    raise ValueError(f"label_span={self.label_span} is not supported")

                if end_idx == start_idx:
                    label_map[label_type][label] = (start_idx,)
                else:
                    label_map[label_type][label] = (start_idx, end_idx)
                span_to_label[label_map[label_type][label]] = {
                    "type": label_type,
                    "task": task,
                    "string": label,
                }
                return label_map[label_type][label]

        if self.include_instructions:
            instruction = instance.get("instruction")
            if not instruction:
                logger.warning(
                    "include_instructions=True, but the instruction is empty!"
                )
        else:
            instruction = ""
        if instruction:
            tokens.append(self.i_token)
            mask.append(2)
            instruction_tokens = self.tokenizer(
                instruction, add_special_tokens=False
            ).tokens()
            tokens.extend(instruction_tokens)
            mask.extend([3] * len(instruction_tokens))

        # register every label in the schema so answers can point at its tokens
        types = instance["schema"].get("cls")
        if types:
            for t in types:
                _update_seq(t, "lc", task="cls")
        mention_types = instance["schema"].get("ent")
        if mention_types:
            for mt in mention_types:
                _update_seq(mt, "lm", task="ent")
        discon_ent_types = instance["schema"].get("discontinuous_ent")
        if discon_ent_types:
            for mt in discon_ent_types:
                _update_seq(mt, "lm", task="discontinuous_ent")
        rel_types = instance["schema"].get("rel")
        if rel_types:
            for rt in rel_types:
                _update_seq(rt, "lr", task="rel")
        hyper_rel_schema = instance["schema"].get("hyper_rel")
        if hyper_rel_schema:
            for rel, qualifiers in hyper_rel_schema.items():
                _update_seq(rel, "lr", task="hyper_rel")
                for qualifier in qualifiers:
                    _update_seq(qualifier, "lr", task="hyper_rel")
        event_schema = instance["schema"].get("event")
        if event_schema:
            for event_type, roles in event_schema.items():
                _update_seq(event_type, "lm", task="event")
                for role in roles:
                    _update_seq(role, "lr", task="event")

        text = instance.get("text")
        if text:
            text_tokenized = self.tokenizer(
                text, return_offsets_mapping=True, add_special_tokens=False
            )
            if any(val for val in label_map.values()):
                text_label_token = self.tl_token
            else:
                text_label_token = self.tp_token
            tokens.append(text_label_token)
            mask.append(6)
            remain_token_len = self.max_seq_len - 1 - len(tokens)
            if remain_token_len < 5 and kwargs.get("dataset_name", "train") == "train":
                return None
            text_off = len(tokens)
            text_tokens = text_tokenized.tokens()[:remain_token_len]
            tokens.extend(text_tokens)
            mask.extend([7] * len(text_tokens))
        else:
            text_tokenized = None

        bg = instance.get("bg")
        if bg:
            bg_tokenized = self.tokenizer(
                bg, return_offsets_mapping=True, add_special_tokens=False
            )
            tokens.append(self.b_token)
            mask.append(8)
            remain_token_len = self.max_seq_len - 1 - len(tokens)
            if remain_token_len < 5 and kwargs.get("dataset_name", "train") == "train":
                return None
            bg_tokens = bg_tokenized.tokens()[:remain_token_len]
            tokens.extend(bg_tokens)
            mask.extend([9] * len(bg_tokens))
        else:
            bg_tokenized = None

        tokens.append(self.tokenizer.sep_token)
        mask.append(10)

        spans = []
        if "cls" in instance["ans"]:
            for t in instance["ans"]["cls"]:
                part = label_map["lc"][t]
                spans.append([part])
        if "ent" in instance["ans"]:
            for ent in instance["ans"]["ent"]:
                label_part = label_map["lm"][ent["type"]]
                position_seq = self.char_to_token_span(
                    ent["span"], text_tokenized, text_off
                )
                spans.append([label_part, position_seq])
        if "discontinuous_ent" in instance["ans"]:
            for ent in instance["ans"]["discontinuous_ent"]:
                label_part = label_map["lm"][ent["type"]]
                ent_span = [label_part]
                for part in ent["span"]:
                    position_seq = self.char_to_token_span(
                        part, text_tokenized, text_off
                    )
                    ent_span.append(position_seq)
                spans.append(ent_span)
        if "rel" in instance["ans"]:
            for rel in instance["ans"]["rel"]:
                label_part = label_map["lr"][rel["relation"]]
                head_position_seq = self.char_to_token_span(
                    rel["head"]["span"], text_tokenized, text_off
                )
                tail_position_seq = self.char_to_token_span(
                    rel["tail"]["span"], text_tokenized, text_off
                )
                spans.append([label_part, head_position_seq, tail_position_seq])
        if "hyper_rel" in instance["ans"]:
            for rel in instance["ans"]["hyper_rel"]:
                label_part = label_map["lr"][rel["relation"]]
                head_position_seq = self.char_to_token_span(
                    rel["head"]["span"], text_tokenized, text_off
                )
                tail_position_seq = self.char_to_token_span(
                    rel["tail"]["span"], text_tokenized, text_off
                )

                for q in rel["qualifiers"]:
                    q_label_part = label_map["lr"][q["label"]]
                    q_position_seq = self.char_to_token_span(
                        q["span"], text_tokenized, text_off
                    )
                    spans.append(
                        [
                            label_part,
                            head_position_seq,
                            tail_position_seq,
                            q_label_part,
                            q_position_seq,
                        ]
                    )
        if "event" in instance["ans"]:
            for event in instance["ans"]["event"]:
                event_type_label_part = label_map["lm"][event["event_type"]]
                trigger_position_seq = self.char_to_token_span(
                    event["trigger"]["span"], text_tokenized, text_off
                )
                trigger_part = [event_type_label_part, trigger_position_seq]
                spans.append(trigger_part)
                for arg in event["args"]:
                    role_label_part = label_map["lr"][arg["role"]]
                    arg_position_seq = self.char_to_token_span(
                        arg["span"], text_tokenized, text_off
                    )
                    arg_part = [role_label_part, trigger_position_seq, arg_position_seq]
                    spans.append(arg_part)
        if "span" in instance["ans"]:
            for span in instance["ans"]["span"]:
                span_position_seq = self.char_to_token_span(
                    span["span"], text_tokenized, text_off
                )
                spans.append([span_position_seq])

        if self.mode == "w2":
            # expand every (start, end) part into its full token index range
            new_spans = []
            for parts in spans:
                new_parts = []
                for part in parts:
                    new_parts.append(tuple(range(part[0], part[-1] + 1)))
                new_spans.append(new_parts)
            spans = new_spans
        elif self.mode == "span":
            # keep the (start, end) boundary pairs as-is
            pass
        else:
            raise ValueError(f"mode={self.mode} is not supported")

        ins = {
            "raw": instance,
            "tokens": tokens,
            "input_ids": self.tokenizer.convert_tokens_to_ids(tokens),
            "mask": mask,
            "spans": spans,
            "label_map": label_map,
            "span_to_label": span_to_label,
            "labels": None,
        }
        return ins

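    # Example of the resulting `spans` structure: an entity answer of type
    # "person" whose [LM] tag sits at token index 2 and whose mention covers
    # token indexes 7-9 is first collected as [(2,), (7, 9)]; mode="w2" then
    # expands each part into its full range, giving [(2,), (7, 8, 9)], while
    # mode="span" keeps the boundary pairs. All indexes refer to the final
    # `tokens` sequence, and the label part (when present) always comes first.
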
    def char_to_token_span(
        self, span: list[int], tokenized: BatchEncoding, offset: int = 0
    ) -> Tuple[int, ...]:
        """Map an end-exclusive character span [start, end) to token indexes shifted by `offset`."""
        token_s = tokenized.char_to_token(span[0])
        token_e = tokenized.char_to_token(span[1] - 1)
        if token_e == token_s:
            position_seq = (offset + token_s,)
        else:
            position_seq = (offset + token_s, offset + token_e)
        return position_seq

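    # Example: for span=[0, 12] covering "Barack Obama" in "Barack Obama visited
    # Paris.", char_to_token(0) and char_to_token(11) might map to tokens 0 and 1;
    # with offset=10 (the position of the first text token in the full sequence)
    # this returns (10, 11), while a single-token mention would return (10,).
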
    def skip_consecutive_span_labels(self, data: dict) -> dict:
        bs = len(data["input_ids"])
        max_seq_len = max(len(input_ids) for input_ids in data["input_ids"])
        batch_seq_len = min(self.max_seq_len, max_seq_len)
        for i in range(bs):
            # truncate and right-pad every instance to the batch sequence length
            data["input_ids"][i] = data["input_ids"][i][:batch_seq_len]
            data["mask"][i] = data["mask"][i][:batch_seq_len]
            assert len(data["input_ids"][i]) == len(data["mask"][i])
            pad_len = batch_seq_len - len(data["mask"][i])
            data["input_ids"][i] = (
                data["input_ids"][i] + [self.tokenizer.pad_token_id] * pad_len
            )
            data["mask"][i] = data["mask"][i] + [0] * pad_len
            # build the label matrices from the collected pointer spans
            data["labels"][i] = encode_nnw_nsw_thw_mat(data["spans"][i], batch_seq_len)

        return data
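
# Input sketch for CachedLabelPointerTransform (the checkpoint name, strings,
# and spans are made-up placeholders):
#
#   transform = CachedLabelPointerTransform(
#       max_seq_len=512, plm_dir="microsoft/deberta-v3-base"
#   )
#   instance = {
#       "instruction": "Extract all person names.",
#       "schema": {"ent": ["person"]},
#       "text": "Barack Obama visited Paris.",
#       "ans": {"ent": [{"type": "person", "span": [0, 12]}]},  # char span, end-exclusive
#   }
#   ins = transform.transform(instance)
#   # ins["spans"] holds [label_part, mention_part] lists; ins["labels"] stays
#   # None here and is filled by `skip_consecutive_span_labels` in the collate fn.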