File size: 17,534 Bytes
914502f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
import argparse
import json
from pathlib import Path
import re
from typing import Dict, Optional, Union
import torch
import torch.nn.functional as F
from modules.layers import LstmSeq2SeqEncoder
from modules.base import InstructBase
from modules.evaluator import Evaluator, greedy_search
from modules.span_rep import SpanRepLayer
from modules.token_rep import TokenRepLayer
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from huggingface_hub.utils import HfHubHTTPError



class GLiNER(InstructBase, PyTorchModelHubMixin):
    """Span-based NER model that scores text spans against entity-type prompts.

    Entity type names are prepended to the input words as a prompt of the form
    ``<<ENT>> type_1 <<ENT>> type_2 ... <<ENT>> type_m <<SEP>>`` and encoded together with
    the words by ``TokenRepLayer``. The ``<<ENT>>`` positions become entity-type embeddings
    (after a feed-forward layer), the word positions are refined by a BiLSTM and turned into
    span embeddings, and the (span, type) score is their dot product.
    """

    def __init__(self, config):
        """Build all sub-modules from ``config``.

        Args:
            config: namespace-like object; the attributes read here are ``model_name``,
                ``fine_tune``, ``subtoken_pooling``, ``hidden_size``, ``span_mode``,
                ``max_width`` and ``dropout``.
        """
        super().__init__(config)

        self.config = config

        # Marker tokens added to the encoder vocabulary (see `add_tokens` below):
        # [ENT] precedes each entity type name in the prompt, [SEP] closes the prompt.
        self.entity_token = "<<ENT>>"
        self.sep_token = "<<SEP>>"

        # Word-level encoder, usually a pretrained bidirectional transformer; how sub-tokens
        # are pooled into one vector per word is controlled by `config.subtoken_pooling`.
        self.token_rep_layer = TokenRepLayer(model_name=config.model_name, fine_tune=config.fine_tune,
                                             subtoken_pooling=config.subtoken_pooling, hidden_size=config.hidden_size,
                                             add_tokens=[self.entity_token, self.sep_token])

        # BiLSTM on top of the encoder; two directions of hidden_size // 2 keep the width at hidden_size.
        self.rnn = LstmSeq2SeqEncoder(
            input_size=config.hidden_size,
            hidden_size=config.hidden_size // 2,
            num_layers=1,
            bidirectional=True,
        )

        # span representation
        self.span_rep_layer = SpanRepLayer(
            span_mode=config.span_mode,
            hidden_size=config.hidden_size,
            max_width=config.max_width,
            dropout=config.dropout,
        )

        # prompt (entity type) representation: 2-layer FFN with a 4x hidden expansion
        self.prompt_rep_layer = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 4),
            nn.Dropout(config.dropout),
            nn.ReLU(),
            nn.Linear(config.hidden_size * 4, config.hidden_size)
        )

    def compute_score_train(self, x):
        """Score every candidate span of a training batch against that example's entity types.

        Each example may carry its own set of entity types (``x['classes_to_id'][i]``), so the
        prompt is built and stripped per example and the type embeddings are padded to the
        largest type count in the batch.

        Args:
            x: batch dict; keys read here are ``tokens`` (one word list per example),
                ``seq_length``, ``span_idx``, ``span_mask`` and ``classes_to_id`` (one dict per
                example, keyed by entity type name).

        Returns:
            Tuple ``(scores, num_classes, entity_type_mask)``:
            - ``scores``: einsum layout ``BLKC`` (batch, L, K, padded type count);
            - ``num_classes``: the padded type count (``scores.shape[-1]``);
            - ``entity_type_mask``: bool ``[batch_size, max_num_classes]``, True for real
              (non-padding) types.
        """
        # Zero the indices of padding spans (assumes span_idx is (B, num_spans, 2) — TODO confirm).
        span_idx = x['span_idx'] * x['span_mask'].unsqueeze(-1)

        new_length = x['seq_length'].clone()
        new_tokens = []
        all_len_prompt = []
        num_classes_all = []

        # Prepend a per-example prompt to the tokens.
        for i, tokens in enumerate(x['tokens']):
            all_types_i = list(x['classes_to_id'][i].keys())
            num_classes_all.append(len(all_types_i))

            # prompt format: [ENT] entity_type_1 [ENT] entity_type_2 ... [ENT] entity_type_m [SEP]
            entity_prompt = []
            for entity_type in all_types_i:
                entity_prompt.append(self.entity_token)
                entity_prompt.append(entity_type)
            entity_prompt.append(self.sep_token)

            # input format: <prompt> token_1 token_2 ... token_n
            new_tokens.append(entity_prompt + tokens)
            # the sequence grows by the prompt length
            new_length[i] = new_length[i] + len(entity_prompt)
            all_len_prompt.append(len(entity_prompt))

        # entity_type_mask[b, c] is True iff c < number of types of example b (False = padding).
        device = x['span_mask'].device
        max_num_classes = max(num_classes_all)
        entity_type_mask = torch.arange(max_num_classes, device=device).unsqueeze(0).expand(len(num_classes_all), -1)
        entity_type_mask = entity_type_mask < torch.tensor(num_classes_all, device=device).unsqueeze(-1)

        # Encode prompt + words jointly.
        bert_output = self.token_rep_layer(new_tokens, new_length)
        word_rep_w_prompt = bert_output["embeddings"]  # embeddings for all positions (with prompt)
        mask_w_prompt = bert_output["mask"]  # mask for all positions (with prompt)

        # Split each example into word part (after [SEP]) and entity-type part (before [SEP]).
        word_rep = []
        mask = []
        entity_type_rep = []
        for i, prompt_entity_length in enumerate(all_len_prompt):
            seq_end = prompt_entity_length + x['seq_length'][i]
            word_rep.append(word_rep_w_prompt[i, prompt_entity_length:seq_end])
            mask.append(mask_w_prompt[i, prompt_entity_length:seq_end])

            # Prompt positions minus the trailing [SEP]. The prompt alternates [ENT], type, so
            # `[0::2]` keeps offsets 0, 2, 4, ... i.e. the [ENT] markers (assuming the encoder
            # returns one vector per input list element).
            entity_rep = word_rep_w_prompt[i, :prompt_entity_length - 1]
            entity_type_rep.append(entity_rep[0::2])

        # Pad the ragged per-example tensors back into batches.
        word_rep = pad_sequence(word_rep, batch_first=True)
        mask = pad_sequence(mask, batch_first=True)
        entity_type_rep = pad_sequence(entity_type_rep, batch_first=True)  # [batch_size, max_num_classes, hidden]

        # span representation
        word_rep = self.rnn(word_rep, mask)
        span_rep = self.span_rep_layer(word_rep, span_idx)

        # final entity type representation (FFN)
        entity_type_rep = self.prompt_rep_layer(entity_type_rep)
        num_classes = entity_type_rep.shape[1]

        # dot-product similarity between every span and every entity type
        scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)

        return scores, num_classes, entity_type_mask

    def forward(self, x):
        """Return the summed, weighted binary cross-entropy training loss for a batch.

        ``x["span_label"]`` holds one label per span: a 1-based entity-type id, 0 for "no
        entity", or -1 for padding spans. Padding spans and padding entity types contribute no
        loss. Positive (span, type) pairs are weighted 2 and negatives 1.

        ``x`` is not modified.

        Returns:
            Scalar tensor: the sum of all unmasked, weighted per-(span, type) losses.
        """
        scores, num_classes, entity_type_mask = self.compute_score_train(x)
        batch_size = scores.shape[0]

        logits_label = scores.view(-1, num_classes)
        labels = x["span_label"].view(-1)  # (batch_size * num_spans)
        mask_label = labels != -1  # (batch_size * num_spans)
        # Out-of-place on purpose: `view` shares storage with x["span_label"], so an in-place
        # fill would overwrite the caller's -1 padding markers and corrupt a reused batch.
        labels = labels.masked_fill(~mask_label, 0)

        # One-hot over num_classes + 1 columns, then drop column 0 ("no entity"), so label c
        # maps to column c - 1 and label 0 becomes an all-zero row.
        labels_one_hot = torch.zeros(labels.size(0), num_classes + 1, dtype=torch.float32, device=scores.device)
        labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)
        labels_one_hot = labels_one_hot[:, 1:]  # (batch_size * num_spans, num_classes)

        all_losses = F.binary_cross_entropy_with_logits(logits_label, labels_one_hot,
                                                        reduction='none')
        # zero the loss on padding entity types (entity_type_mask is (B, C))
        masked_loss = all_losses.view(batch_size, -1, num_classes) * entity_type_mask.unsqueeze(1)
        all_losses = masked_loss.view(-1, num_classes)
        # zero the loss on padding spans
        mask_label = mask_label.unsqueeze(-1).expand_as(all_losses)
        # weight 2 for positive (span, type) pairs, 1 for negatives
        weight_c = labels_one_hot + 1
        all_losses = all_losses * mask_label.float() * weight_c
        return all_losses.sum()

    def compute_score_eval(self, x, device):
        """Score every candidate span against one entity-type set shared by the whole batch.

        Args:
            x: batch dict with ``tokens``, ``seq_length``, ``span_idx``, ``span_mask`` and
                ``classes_to_id`` (a single dict keyed by entity type name).
            device: device that ``span_idx`` is moved to.

        Returns:
            Score tensor with einsum layout ``BLKC`` (batch, L, K, number of types).
        """
        assert isinstance(x['classes_to_id'], dict), "classes_to_id must be a dict"

        span_idx = (x['span_idx'] * x['span_mask'].unsqueeze(-1)).to(device)

        # prompt format: [ENT] entity_type_1 ... [ENT] entity_type_m [SEP]
        entity_prompt = []
        for entity_type in x['classes_to_id'].keys():
            entity_prompt.append(self.entity_token)
            entity_prompt.append(entity_type)
        entity_prompt.append(self.sep_token)

        prompt_entity_length = len(entity_prompt)

        # Same prompt for every example, so lengths shift uniformly.
        tokens_p = [entity_prompt + tokens for tokens in x['tokens']]
        seq_length_p = x['seq_length'] + prompt_entity_length

        out = self.token_rep_layer(tokens_p, seq_length_p)

        word_rep_w_prompt = out["embeddings"]
        mask_w_prompt = out["mask"]

        # drop the prompt positions from the word part
        word_rep = word_rep_w_prompt[:, prompt_entity_length:, :]
        mask = mask_w_prompt[:, prompt_entity_length:]

        # prompt without [SEP]; even offsets are the [ENT] markers
        entity_type_rep = word_rep_w_prompt[:, :prompt_entity_length - 1, :]
        entity_type_rep = entity_type_rep[:, 0::2, :]

        entity_type_rep = self.prompt_rep_layer(entity_type_rep)

        word_rep = self.rnn(word_rep, mask)

        span_rep = self.span_rep_layer(word_rep, span_idx)

        local_scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)

        return local_scores

    @torch.no_grad()
    def predict(self, x, flat_ner=False, threshold=0.5):
        """Decode entity spans for a batch (switches the model to eval mode).

        Args:
            x: batch dict as for ``compute_score_eval`` plus ``id_to_classes`` (1-based ids).
            flat_ner: forwarded to ``greedy_search``.
            threshold: minimum sigmoid probability for a (span, type) to be kept.

        Returns:
            One list per example: ``greedy_search`` applied to candidates of the form
            ``(start, end, label, score)`` where ``end`` is the inclusive last word index.
        """
        self.eval()
        local_scores = self.compute_score_eval(x, device=next(self.parameters()).device)
        spans = []
        for i, tokens in enumerate(x["tokens"]):
            local_i = local_scores[i]
            num_tokens = len(tokens)
            # (start, width, type) index triples whose probability clears the threshold
            hits = [idx.tolist() for idx in torch.where(torch.sigmoid(local_i) > threshold)]
            span_i = []
            for s, k, c in zip(*hits):
                # skip spans that would run past this example's last word
                if s + k < num_tokens:
                    span_i.append((s, s + k, x["id_to_classes"][c + 1], local_i[s, k, c]))
            spans.append(greedy_search(span_i, flat_ner))
        return spans

    def predict_entities(self, text, labels, flat_ner=True, threshold=0.5):
        """Find entities of the given ``labels`` in raw ``text``.

        Returns:
            List of dicts with ``start``/``end`` character offsets (``end`` exclusive),
            the matched ``text`` and its ``label``.
        """
        tokens = []
        start_token_idx_to_text_idx = []
        end_token_idx_to_text_idx = []
        # word tokenizer: word-char runs (optionally joined by '-' or '_') or any other non-space char
        for match in re.finditer(r'\w+(?:[-_]\w+)*|\S', text):
            tokens.append(match.group())
            start_token_idx_to_text_idx.append(match.start())
            end_token_idx_to_text_idx.append(match.end())

        input_x = {"tokenized_text": tokens, "ner": None}
        x = self.collate_fn([input_x], labels)
        output = self.predict(x, flat_ner=flat_ner, threshold=threshold)

        entities = []
        # NOTE(review): assumes greedy_search returns (start, end, label) triples — confirm.
        for start_token_idx, end_token_idx, ent_type in output[0]:
            start_text_idx = start_token_idx_to_text_idx[start_token_idx]
            end_text_idx = end_token_idx_to_text_idx[end_token_idx]
            entities.append({
                "start": start_text_idx,
                "end": end_text_idx,
                "text": text[start_text_idx:end_text_idx],
                "label": ent_type,
            })
        return entities

    def evaluate(self, test_data, flat_ner=False, threshold=0.5, batch_size=12, entity_types=None):
        """Predict on ``test_data`` and score against its gold entities.

        Returns:
            ``(out, f1)`` as produced by ``Evaluator.evaluate()``.
        """
        self.eval()
        data_loader = self.create_dataloader(test_data, batch_size=batch_size, entity_types=entity_types, shuffle=False)
        device = next(self.parameters()).device
        all_preds = []
        all_trues = []
        for x in data_loader:
            # move every tensor in the batch to the model's device
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
            batch_predictions = self.predict(x, flat_ner, threshold)
            all_preds.extend(batch_predictions)
            all_trues.extend(x["entities"])
        evaluator = Evaluator(all_trues, all_preds)
        out, f1 = evaluator.evaluate()
        return out, f1

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Load a model from a local directory or a Hub repo.

        Tries the legacy single-file checkpoints (``gliner_base.pt``, ``gliner_multi.pt``)
        first, then the newer ``pytorch_model.bin`` + ``gliner_config.json`` pair.
        """
        download_kwargs = {
            "revision": revision,
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "resume_download": resume_download,
            "token": token,
            "local_files_only": local_files_only,
        }
        # A local directory path is not a Hub repo id; never send it to the Hub.
        is_local_dir = Path(model_id).is_dir()

        # 1. Backwards compatibility: "gliner_base.pt" / "gliner_multi.pt" hold config + weights.
        filenames = ["gliner_base.pt", "gliner_multi.pt"]
        for filename in filenames:
            model_file = Path(model_id) / filename
            if not model_file.exists():
                if is_local_dir:
                    continue
                try:
                    model_file = hf_hub_download(repo_id=model_id, filename=filename, **download_kwargs)
                except HfHubHTTPError:
                    continue
            # NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
            dict_load = torch.load(model_file, map_location=torch.device(map_location))
            config = dict_load["config"]
            state_dict = dict_load["model_weights"]
            config.model_name = "microsoft/deberta-v3-base" if filename == "gliner_base.pt" else "microsoft/mdeberta-v3-base"
            model = cls(config)
            model.load_state_dict(state_dict, strict=strict, assign=True)
            # GLiNER.to also syncs flair's global device when flair is installed.
            model.to(map_location)
            return model

        # 2. Newer format: "pytorch_model.bin" and "gliner_config.json"
        from train import load_config_as_namespace

        model_file = Path(model_id) / "pytorch_model.bin"
        if not model_file.exists():
            model_file = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", **download_kwargs)
        config_file = Path(model_id) / "gliner_config.json"
        if not config_file.exists():
            config_file = hf_hub_download(repo_id=model_id, filename="gliner_config.json", **download_kwargs)
        config = load_config_as_namespace(config_file)
        model = cls(config)
        # NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
        state_dict = torch.load(model_file, map_location=torch.device(map_location))
        model.load_state_dict(state_dict, strict=strict, assign=True)
        model.to(map_location)
        return model

    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        *,
        config: Optional[Union[dict, "DataclassInstance"]] = None,
        repo_id: Optional[str] = None,
        push_to_hub: bool = False,
        **push_to_hub_kwargs,
    ) -> Optional[str]:
        """
        Save weights in local directory.

        Writes ``pytorch_model.bin`` (the state dict) and ``gliner_config.json``.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the model weights and configuration will be saved.
            config (`dict`, `argparse.Namespace` or `DataclassInstance`, *optional*):
                Model configuration; defaults to ``self.config``. Namespaces and dataclass
                instances are converted to a dict before being written as JSON.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Huggingface Hub after saving it.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
                not provided.
            kwargs:
                Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.

        Returns:
            The result of ``push_to_hub`` when pushing, otherwise ``None``.
        """
        import dataclasses

        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        torch.save(self.state_dict(), save_directory / "pytorch_model.bin")

        if config is None:
            config = self.config
        if config is not None:
            # json cannot serialize Namespace / dataclass instances directly
            if isinstance(config, argparse.Namespace):
                config = vars(config)
            elif dataclasses.is_dataclass(config) and not isinstance(config, type):
                config = dataclasses.asdict(config)
            (save_directory / "gliner_config.json").write_text(json.dumps(config, indent=2))

        if push_to_hub:
            kwargs = push_to_hub_kwargs.copy()  # shallow copy to avoid mutating input
            if config is not None:  # kwarg for `push_to_hub`
                kwargs["config"] = config
            if repo_id is None:
                repo_id = save_directory.name  # Defaults to `save_directory` name
            return self.push_to_hub(repo_id=repo_id, **kwargs)
        return None

    def to(self, *args, **kwargs):
        """Move/cast the module exactly like ``nn.Module.to``.

        If ``flair`` is importable, its global ``flair.device`` is also set to the device
        the parameters ended up on. flair is optional; without it this is a plain ``to``.

        Returns:
            ``self``.
        """
        super().to(*args, **kwargs)
        try:
            import flair
        except ImportError:
            return self
        # Use the resulting parameter device so dtype-only casts don't overwrite flair.device.
        flair.device = next(self.parameters()).device
        return self