riccorl's picture
first commit
626eca0
raw
history blame
No virus
9.63 kB
import collections
import itertools
from dataclasses import dataclass
from typing import List, Optional, Set, Tuple
from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
from relik.reader.data.relik_reader_sample import RelikReaderSample
@dataclass
class Window:
    """A contiguous, tokenized slice of a source document.

    The fields mirror the keyword arguments that ``WindowManager.create_windows``
    uses when it builds a window sample. All fields except
    ``window_candidates`` are required.
    """

    # Identifier of the document the window was cut from.
    doc_id: int
    # Index of this window among the windows of its document.
    window_id: int
    # Raw text covered by the window.
    text: str
    # Token strings of the window, in order.
    tokens: List[str]
    # Topic string attached to the whole document (may be None).
    doc_topic: Optional[str]
    # Character offset of ``text`` inside the source document.
    offset: int
    # Token position -> start character in the source document.
    token2char_start: dict
    # Token position -> end character in the source document.
    token2char_end: dict
    # Optional candidate strings attached to this window.
    window_candidates: Optional[List[str]] = None
class WindowManager:
    """Split documents into overlapping token windows and merge windows back.

    Windows are represented as ``RelikReaderSample`` objects.
    """

    def __init__(self, tokenizer: BaseTokenizer) -> None:
        # ``tokenize`` calls this on raw text; every item it yields must
        # expose ``text``, ``start_char`` and ``end_char``.
        self.tokenizer = tokenizer

    def tokenize(self, document: str) -> Tuple[List[str], List[Tuple[int, int]]]:
        """Tokenize ``document`` with the configured tokenizer.

        Returns:
            ``(tokens, char_spans)`` where ``tokens[i]`` is the text of the
            i-th token and ``char_spans[i]`` is its ``(start_char, end_char)``
            span in ``document``.
        """
        tokens = []
        tokens_char_mapping = []
        for token in self.tokenizer(document):
            tokens.append(token.text)
            tokens_char_mapping.append((token.start_char, token.end_char))
        return tokens, tokens_char_mapping

    def create_windows(
        self,
        document: str,
        window_size: int,
        stride: int,
        doc_id: int = 0,
        doc_topic: Optional[str] = None,
    ) -> List[RelikReaderSample]:
        """Split ``document`` into overlapping windows of at most ``window_size`` tokens.

        Args:
            document: raw text to split.
            window_size: maximum number of tokens per window.
            stride: number of tokens between the starts of consecutive windows.
            doc_id: identifier stored on every produced window.
            doc_topic: topic stored on every window. Defaults to the first
                token of the document, or ``""`` for an empty document.

        Returns:
            One ``RelikReaderSample`` per window. ``token2char_start`` and
            ``token2char_end`` map the stringified window-local token position
            to character offsets in the full ``document``. ``offset`` is the
            character offset of the window text in ``document``. Every token
            of the document appears in at least one window.

        Raises:
            ValueError: if the document needs more than one window and
                ``window_size`` or ``stride`` is not positive.
        """
        document_tokens, tokens_char_mapping = self.tokenize(document)
        if doc_topic is None:
            doc_topic = document_tokens[0] if document_tokens else ""

        if len(document_tokens) <= window_size:
            # The whole document fits in one window: keep its text verbatim.
            return [
                self._make_window_sample(
                    doc_id=doc_id,
                    window_id=0,
                    text=document,
                    doc_topic=doc_topic,
                    offset=0,
                    token_indices=list(range(len(document_tokens))),
                    document_tokens=document_tokens,
                    tokens_char_mapping=tokens_char_mapping,
                )
            ]

        document_windows = []
        for window_id, start, end in self._window_token_ranges(
            len(document_tokens), window_size, stride
        ):
            window_text_start = tokens_char_mapping[start][0]
            window_text_end = tokens_char_mapping[end - 1][1]
            document_windows.append(
                self._make_window_sample(
                    doc_id=doc_id,
                    window_id=window_id,
                    text=document[window_text_start:window_text_end],
                    doc_topic=doc_topic,
                    offset=window_text_start,
                    token_indices=list(range(start, end)),
                    document_tokens=document_tokens,
                    tokens_char_mapping=tokens_char_mapping,
                )
            )
        return document_windows

    @staticmethod
    def _window_token_ranges(
        num_tokens: int, window_size: int, stride: int
    ) -> List[Tuple[int, int, int]]:
        """Return ``(window_id, start, end)`` token ranges (``end`` exclusive).

        Windows start every ``stride`` tokens. The last window is shifted back
        so that it ends exactly on the document's final token.
        """
        if window_size <= 0 or stride <= 0:
            raise ValueError(
                f"window_size and stride must be positive, got {window_size} and {stride}"
            )
        ranges = []
        for window_id, start in enumerate(range(0, num_tokens, stride)):
            # If the last stride is smaller than the window size, include more
            # tokens from the previous window.
            if start != 0 and start + window_size > num_tokens:
                overflowing_tokens = start + window_size - num_tokens
                if overflowing_tokens >= stride:
                    # The previous window already reaches the end of the document.
                    break
                start -= overflowing_tokens
            # ``end`` is exclusive, so the final token (index num_tokens - 1)
            # belongs to the last window.
            ranges.append((window_id, start, min(start + window_size, num_tokens)))
        return ranges

    @staticmethod
    def _make_window_sample(
        doc_id: int,
        window_id: int,
        text: str,
        doc_topic: str,
        offset: int,
        token_indices: List[int],
        document_tokens: List[str],
        tokens_char_mapping: List[Tuple[int, int]],
    ) -> RelikReaderSample:
        """Build the sample for the window covering ``token_indices`` of the document."""
        return RelikReaderSample(
            doc_id=doc_id,
            window_id=window_id,
            text=text,
            tokens=[document_tokens[j] for j in token_indices],
            doc_topic=doc_topic,
            offset=offset,
            # Keys: window-local token positions as strings; values: character
            # offsets in the full document.
            token2char_start={
                str(local): tokens_char_mapping[j][0]
                for local, j in enumerate(token_indices)
            },
            token2char_end={
                str(local): tokens_char_mapping[j][1]
                for local, j in enumerate(token_indices)
            },
        )

    def merge_windows(
        self, windows: List[RelikReaderSample]
    ) -> List[RelikReaderSample]:
        """Merge windows sharing a ``doc_id`` into one sample per document.

        Documents are returned in the order their first window appears in
        ``windows``.
        """
        windows_by_doc_id = collections.defaultdict(list)
        for window in windows:
            windows_by_doc_id[window.doc_id].append(window)
        return [
            self.merge_doc_windows(doc_windows)
            for doc_windows in windows_by_doc_id.values()
        ]

    def merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSample:
        """Merge all windows of a single document into one sample.

        A single window is returned unchanged. Otherwise the windows are sorted
        by ``offset`` (when the first window has one) and folded left to right
        with ``_merge_window_pair``. ``windows`` must be non-empty.
        """
        if len(windows) == 1:
            return windows[0]
        if windows and getattr(windows[0], "offset", None) is not None:
            windows = sorted(windows, key=lambda x: x.offset)
        window_accumulator = windows[0]
        for next_window in windows[1:]:
            window_accumulator = self._merge_window_pair(
                window_accumulator, next_window
            )
        return window_accumulator

    def _merge_tokens(
        self, window1: RelikReaderSample, window2: RelikReaderSample
    ) -> Tuple[list, dict, dict]:
        """Concatenate the tokens of two overlapping windows.

        The first and last token of each window are treated as boundary tokens
        (presumably special tokens such as CLS / SEP) and are excluded from
        the overlap search. The overlap is the longest proper suffix of
        ``window1``'s inner tokens that equals a prefix of ``window2``'s inner
        tokens; an ``AssertionError`` is raised when none exists.

        Returns:
            ``(tokens, token2char_start, token2char_end)`` for the merged
            window. ``window1``'s mappings are kept as-is. ``window2``'s keys
            below the overlap length are dropped and the rest are shifted by
            ``len(inner tokens of window1) - overlap``.
        """
        w1_tokens = window1.tokens[1:-1]
        w2_tokens = window2.tokens[1:-1]

        # Longest k < len(w1_tokens) such that w1's last k tokens == w2's first k.
        tokens_intersection = None
        for k in reversed(range(1, len(w1_tokens))):
            if w1_tokens[-k:] == w2_tokens[:k]:
                tokens_intersection = k
                break
        # ``sent_id`` is only used for diagnostics and is not set by
        # ``create_windows``, so read it defensively.
        assert tokens_intersection is not None, (
            f"{window1.doc_id} - {getattr(window1, 'sent_id', None)} - {window1.offset}"
            + f" {window2.doc_id} - {getattr(window2, 'sent_id', None)} - {window2.offset}\n"
            + f"w1 tokens: {w1_tokens}\n"
            + f"w2 tokens: {w2_tokens}\n"
        )

        final_tokens = (
            [window1.tokens[0]]  # CLS
            + w1_tokens
            + w2_tokens[tokens_intersection:]
            + [window1.tokens[-1]]  # SEP
        )

        w2_starting_offset = len(w1_tokens) - tokens_intersection

        def merge_char_mapping(t2c1: dict, t2c2: dict) -> dict:
            final_t2c = dict(t2c1)
            for t, c in t2c2.items():
                t = int(t)
                if t < tokens_intersection:
                    continue
                final_t2c[str(t + w2_starting_offset)] = c
            return final_t2c

        return (
            final_tokens,
            merge_char_mapping(window1.token2char_start, window2.token2char_start),
            merge_char_mapping(window1.token2char_end, window2.token2char_end),
        )

    def _merge_span_annotation(
        self, span_annotation1: List[list], span_annotation2: List[list]
    ) -> List[list]:
        """Union two span-annotation lists, dropping exact duplicates.

        First occurrences are kept. The result is stably sorted by each
        annotation's first element. ``None`` is treated as an empty list.
        """
        uniq_store = set()
        final_span_annotation_store = []
        for span_annotation in itertools.chain(
            span_annotation1 or (), span_annotation2 or ()
        ):
            span_annotation_id = tuple(span_annotation)
            if span_annotation_id not in uniq_store:
                uniq_store.add(span_annotation_id)
                final_span_annotation_store.append(span_annotation)
        return sorted(final_span_annotation_store, key=lambda x: x[0])

    def _merge_predictions(
        self,
        window1: RelikReaderSample,
        window2: RelikReaderSample,
    ) -> Tuple[Set[Tuple[int, int, str]], dict]:
        """Union the predicted labels of two windows and merge their probabilities.

        Returns:
            ``(predicted_labels, probabilities)``. For a span present in both
            windows, ``window1``'s probabilities are kept.
        """
        merged_predictions = window1.predicted_window_labels_chars.union(
            window2.predicted_window_labels_chars
        )
        span_title_probabilities = dict()
        for span_prediction, predicted_probs in itertools.chain(
            window1.probs_window_labels_chars.items(),
            window2.probs_window_labels_chars.items(),
        ):
            # window1 is iterated first, so its entry wins on conflicts.
            span_title_probabilities.setdefault(span_prediction, predicted_probs)
        return merged_predictions, span_title_probabilities

    def _merge_window_pair(
        self,
        window1: RelikReaderSample,
        window2: RelikReaderSample,
    ) -> RelikReaderSample:
        """Merge two consecutive windows of the same document into one sample.

        The result carries ``doc_id`` from ``window1`` and ``offset`` from
        ``window2``, plus the merged tokens, character mappings, window labels
        (``None`` if ``window1`` has none) and predictions.
        """
        merging_output = dict()

        if getattr(window1, "doc_id", None) is not None:
            assert window1.doc_id == window2.doc_id

        if getattr(window1, "offset", None) is not None:
            assert (
                window1.offset < window2.offset
            ), f"window 2 offset ({window2.offset}) is smaller than window 1 offset({window1.offset})"

        merging_output["doc_id"] = window1.doc_id
        merging_output["offset"] = window2.offset

        m_tokens, m_token2char_start, m_token2char_end = self._merge_tokens(
            window1, window2
        )

        window_labels = None
        if getattr(window1, "window_labels", None) is not None:
            window_labels = self._merge_span_annotation(
                window1.window_labels, window2.window_labels
            )

        (
            predicted_window_labels_chars,
            probs_window_labels_chars,
        ) = self._merge_predictions(window1, window2)

        merging_output.update(
            dict(
                tokens=m_tokens,
                token2char_start=m_token2char_start,
                token2char_end=m_token2char_end,
                window_labels=window_labels,
                predicted_window_labels_chars=predicted_window_labels_chars,
                probs_window_labels_chars=probs_window_labels_chars,
            )
        )

        return RelikReaderSample(**merging_output)