riccorl's picture
first commit
626eca0
from typing import List
from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.reader.utils.special_symbols import NME_SYMBOL
def merge_patches_predictions(sample) -> None:
    """Merge the per-patch predictions of ``sample`` into sample-level fields.

    Patches are visited in ascending order of their keys in ``sample.patches``.
    Two entries of ``sample._d`` are reset to fresh dicts and then filled in:

    * ``"predicted_window_labels"``: maps each title to the list of spans
      assigned to it. Every span (converted to a tuple) keeps the first title
      predicted for it. The exception is a current title of ``None`` or
      ``NME_SYMBOL``, which is replaced by the next title predicted for that span.
    * ``"span_title_probabilities"``: for each span key, the value from the
      first patch that reports it is kept; later patches never overwrite it.

    Args:
        sample: object exposing a ``_d`` dict and a ``patches`` mapping whose
            values provide ``"predicted_window_labels"`` (title -> spans) and
            ``"span_title_probabilities"`` (span -> probabilities) entries.
            NOTE(review): presumably a ``RelikReaderSample`` — confirm at callers.

    Returns:
        None; ``sample`` is updated in place.
    """
    window_labels = {}
    title_probabilities = {}
    sample._d["predicted_window_labels"] = window_labels
    sample._d["span_title_probabilities"] = title_probabilities

    # span (as a tuple) -> title chosen so far, in first-seen span order
    chosen_title = {}

    for _, patch in sorted(sample.patches.items(), key=lambda kv: kv[0]):
        for title, spans in patch["predicted_window_labels"].items():
            for span in map(tuple, spans):
                previous = chosen_title.get(span)
                # an unassigned or NME span may still be upgraded to a real title
                if previous is None or previous == NME_SYMBOL:
                    chosen_title[span] = title

        # the first patch that reports a span wins; later reports are ignored
        for span, probabilities in patch["span_title_probabilities"].items():
            title_probabilities.setdefault(span, probabilities)

    for span, title in chosen_title.items():
        window_labels.setdefault(title, []).append(span)
def remove_duplicate_samples(
    samples: List[RelikReaderSample],
) -> List[RelikReaderSample]:
    """Return ``samples`` without duplicates, keeping first occurrences in order.

    Two samples are considered duplicates when the text forms of their
    ``doc_id``, ``sent_id`` and ``offset`` attributes are all equal.

    Args:
        samples: samples to deduplicate; the input list is not modified.

    Returns:
        A new list holding the first sample seen for each identity, in the
        original input order.
    """
    seen_ids = set()
    unique_samples = []
    for sample in samples:
        # Keep the three components separate in a tuple instead of joining
        # them with "#": a joined string lets distinct triples collide when a
        # component itself contains "#" (e.g. doc_id "a#1", sent_id 2 versus
        # doc_id "a", sent_id "1#2"), which would silently drop a real sample.
        # Formatting each field keeps the key hashable whatever the field type.
        sample_id = (f"{sample.doc_id}", f"{sample.sent_id}", f"{sample.offset}")
        if sample_id in seen_ids:
            continue
        seen_ids.add(sample_id)
        unique_samples.append(sample)
    return unique_samples