File size: 5,951 Bytes
626eca0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import logging
from typing import Iterable, Iterator, List, Optional

import hydra
import torch
from lightning.pytorch.utilities import move_data_to_device
from torch.utils.data import DataLoader
from tqdm import tqdm

from relik.reader.data.patches import merge_patches_predictions
from relik.reader.data.relik_reader_sample import (
    RelikReaderSample,
    load_relik_reader_samples,
)
from relik.reader.relik_reader_core import RelikReaderCoreModel
from relik.reader.utils.special_symbols import NME_SYMBOL

logger = logging.getLogger(__name__)


def convert_tokens_to_char_annotations(
    sample: RelikReaderSample, remove_nmes: bool = False
):
    """Project a sample's token-level predictions onto character offsets.

    Two attributes of ``sample`` are read:

    * ``predicted_window_labels``: mapping of label -> iterable of
      ``(start, end)`` token spans;
    * ``span_title_probabilities``: mapping of ``(start, end)`` token span ->
      iterable of ``(title, score)`` pairs.

    Every span boundary is translated through ``sample.token2char_start`` and
    ``sample.token2char_end``, both of which are indexed with the *stringified*
    token position. The results are written back onto the sample:

    * ``predicted_window_labels_chars``: set of
      ``(char_start, char_end, label)`` triples;
    * ``probs_window_labels_chars``: dict mapping ``(char_start, char_end)`` to
      the set of candidate titles (the scores themselves are dropped).

    Args:
        sample: Sample to annotate; it is modified in place.
        remove_nmes: When true, spans labelled ``NME_SYMBOL`` are left out of
            ``predicted_window_labels_chars``.
    """

    def to_char_span(token_start, token_end):
        # The offset tables are keyed by str(token_index), not by the int.
        return (
            sample.token2char_start[str(token_start)],
            sample.token2char_end[str(token_end)],
        )

    labelled_char_spans = set()
    for label, token_spans in sample.predicted_window_labels.items():
        if remove_nmes and label == NME_SYMBOL:
            continue
        labelled_char_spans.update(
            (*to_char_span(tok_start, tok_end), label)
            for tok_start, tok_end in token_spans
        )

    # Only the candidate titles are kept for each span.
    titles_by_char_span = {
        to_char_span(tok_start, tok_end): {title for title, _ in scored_titles}
        for (tok_start, tok_end), scored_titles
        in sample.span_title_probabilities.items()
    }

    sample.predicted_window_labels_chars = labelled_char_spans
    sample.probs_window_labels_chars = titles_by_char_span


class RelikReaderPredictor:
    """Run a reader model over samples and return predictions in input order.

    The predictor wraps a ``RelikReaderCoreModel`` and, optionally, a dataset
    built once from a Hydra configuration and reused across calls.
    """

    def __init__(
        self,
        relik_reader_core: RelikReaderCoreModel,
        dataset_conf: Optional[dict] = None,
        predict_nmes: bool = False,
    ) -> None:
        """Store the model and, if configured, instantiate a reusable dataset.

        Args:
            relik_reader_core: Model whose ``batch_predict`` is called on every
                batch and whose parameters determine the target device.
            dataset_conf: Hydra configuration of the dataset. When given, the
                dataset is instantiated once here (with no path and no
                samples) and reused by every prediction call.
            predict_nmes: When false, spans labelled ``NME_SYMBOL`` are removed
                from the character-level annotations produced by ``predict``.
        """
        self.relik_reader_core = relik_reader_core
        self.dataset_conf = dataset_conf
        self.predict_nmes = predict_nmes
        # Always define the attribute so it can be inspected even when no
        # configuration was supplied.
        self.dataset = None

        if self.dataset_conf is not None:
            # instantiate dataset
            self.dataset = hydra.utils.instantiate(
                dataset_conf,
                dataset_path=None,
                samples=None,
            )

    def predict(
        self,
        path: Optional[str] = None,
        samples: Optional[Iterable[RelikReaderSample]] = None,
        dataset_conf: Optional[dict] = None,
        token_batch_size: int = 1024,
        progress_bar: bool = False,
        **kwargs,
    ) -> List[RelikReaderSample]:
        """Predict on a file or an iterable of samples and post-process results.

        All predictions are collected first; each predicted sample is then
        passed to ``merge_patches_predictions`` and
        ``convert_tokens_to_char_annotations``.

        Args:
            path: Location to load samples from when ``samples`` is ``None``.
            samples: Samples to predict on; takes precedence over ``path``.
            dataset_conf: Hydra dataset configuration, used only when no
                dataset was instantiated in the constructor.
            token_batch_size: Value used as the dataset's ``tokens_per_batch``.
            progress_bar: Whether to show a tqdm progress bar.
            **kwargs: Accepted for call compatibility; not used.

        Returns:
            The list of predicted samples.
        """
        annotated_samples = list(
            self._predict(path, samples, dataset_conf, token_batch_size, progress_bar)
        )
        for sample in annotated_samples:
            merge_patches_predictions(sample)
            convert_tokens_to_char_annotations(
                sample, remove_nmes=not self.predict_nmes
            )
        return annotated_samples

    def _predict(
        self,
        path: Optional[str] = None,
        samples: Optional[Iterable[RelikReaderSample]] = None,
        dataset_conf: Optional[dict] = None,
        token_batch_size: int = 1024,
        progress_bar: bool = False,
        **kwargs,
    ) -> Iterator[RelikReaderSample]:
        """Lazily yield predicted samples, restoring the input order.

        Each input sample is tagged with its position
        (``_mixin_prediction_position``) before it enters the dataset.
        Predictions are buffered by position and released as soon as the next
        expected position is available, so output order matches input order
        even if batches come back in a different order.

        Args:
            path: Location to load samples from when ``samples`` is ``None``.
            samples: Samples to predict on; takes precedence over ``path``.
            dataset_conf: Hydra dataset configuration, used only when the
                instance holds no pre-built dataset.
            token_batch_size: Value used as the dataset's ``tokens_per_batch``.
            progress_bar: Whether to wrap the data loader in a tqdm bar.
            **kwargs: Accepted for call compatibility; not used.

        Yields:
            Predicted samples. If some positions never receive a prediction,
            a warning is logged and the remaining buffered predictions are
            yielded sorted by position.

        Raises:
            AssertionError: If both ``path`` and ``samples`` are ``None``, or
                if an input sample already carries a prediction position.
            ValueError: If there is neither a pre-built dataset nor a
                ``dataset_conf`` to build one from.
        """
        assert (
            path is not None or samples is not None
        ), "Either predict on a path or on an iterable of samples"

        samples = load_relik_reader_samples(path) if samples is None else samples

        # setup infrastructure to re-yield in order
        def samples_it():
            for i, sample in enumerate(samples):
                assert sample._mixin_prediction_position is None
                sample._mixin_prediction_position = i
                yield sample

        next_prediction_position = 0
        position2predicted_sample = {}

        # instantiate dataset
        # getattr keeps working for instances created without __init__.
        dataset = getattr(self, "dataset", None)
        if dataset is not None:
            # Reuse the cached dataset: only its input and batch size change.
            dataset.samples = samples_it()
            dataset.tokens_per_batch = token_batch_size
        else:
            if dataset_conf is None:
                # Instantiating a None config yields no dataset; fail here
                # with a clear message instead of deep inside the DataLoader.
                raise ValueError(
                    "No dataset available: pass `dataset_conf` either to the "
                    "constructor or to the prediction call"
                )
            dataset = hydra.utils.instantiate(
                dataset_conf,
                dataset_path=None,
                samples=samples_it(),
                tokens_per_batch=token_batch_size,
            )

        # instantiate dataloader
        # batch_size=None: the DataLoader does no batching of its own.
        iterator = DataLoader(dataset, batch_size=None, num_workers=0, shuffle=False)
        if progress_bar:
            iterator = tqdm(iterator, desc="Predicting")

        model_device = next(self.relik_reader_core.parameters()).device
        # Loop-invariant: autocast device type matching the model's device.
        autocast_device = "cpu" if model_device == torch.device("cpu") else "cuda"

        try:
            with torch.inference_mode():
                for batch in iterator:
                    # do batch predict
                    with torch.autocast(autocast_device):
                        batch = move_data_to_device(batch, model_device)
                        batch_out = self.relik_reader_core.batch_predict(**batch)
                    # update prediction position
                    for sample in batch_out:
                        # Positions that were already yielded are ignored.
                        if (
                            sample._mixin_prediction_position
                            >= next_prediction_position
                        ):
                            position2predicted_sample[
                                sample._mixin_prediction_position
                            ] = sample

                    # yield every prediction that is now contiguous with the
                    # ones already emitted
                    while next_prediction_position in position2predicted_sample:
                        yield position2predicted_sample[next_prediction_position]
                        del position2predicted_sample[next_prediction_position]
                        next_prediction_position += 1

            if position2predicted_sample:
                # A gap in the positions kept these predictions buffered.
                logger.warning(
                    "It seems samples have been discarded in your dataset. "
                    "This means that you WON'T have a prediction for each input sample. "
                    "Prediction order will also be partially disrupted"
                )
                for position in sorted(position2predicted_sample):
                    yield position2predicted_sample[position]
        finally:
            # Close the bar even on errors or when the consumer stops early.
            if progress_bar:
                iterator.close()