kargaranamir committed on
Commit
a220ee8
1 Parent(s): 62b1636

Create utils.py

Files changed (1)
  1. utils.py +490 -0
utils.py ADDED
@@ -0,0 +1,490 @@
"""
Utilities for Hengam inference
"""

"""### Import Libraries"""

# import basic libraries
import os
import pandas as pd
from tqdm import tqdm
import numpy as np
import json

# import seqeval to report classifier performance metrics
from seqeval.metrics import accuracy_score, precision_score, recall_score, f1_score
from seqeval.scheme import IOB2

# import torch-related modules
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn

# import the PyTorch Lightning library
import pytorch_lightning as pl
from torchcrf import CRF as SUPERCRF

# import NLTK to build a better tokenizer
import nltk
from nltk.tokenize import RegexpTokenizer

# Transformers: XLM-RoBERTa model
from transformers import XLMRobertaTokenizerFast
from transformers import XLMRobertaModel, XLMRobertaConfig


# import typing helpers
from typing import Union, Dict, List, Tuple, Any, Optional

import glob

# for the sentence tokenizer (nltk)
nltk.download('punkt')

"""## XLM-Roberta
### TokenFromSubtoken
- Code adapted from the following [file](https://github.com/deepmipt/DeepPavlov/blob/master/deeppavlov/models/torch_bert/torch_transformers_sequence_tagger.py)
- DeepPavlov is a popular open-source library for end-to-end deep learning dialog systems and chatbots.
- Licensed under the Apache License, Version 2.0 (the "License");
"""

class TokenFromSubtoken(torch.nn.Module):

    def forward(self, units: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Assemble token-level units from subtoken-level units.

        Args:
            units: torch.Tensor of shape [batch_size, SUBTOKEN_seq_length, n_features]
            mask: mask of token beginnings. For example, for the tokens
                [[``[CLS]``, ``My``, ``capybara``, ``[SEP]``],
                 [``[CLS]``, ``Your``, ``aar``, ``##dvark``, ``is``, ``awesome``, ``[SEP]``]]
                the mask will be
                [[0, 1, 1, 0, 0, 0, 0],
                 [0, 1, 1, 0, 1, 1, 0]]
        Returns:
            word_level_units: units assembled from the ones in the mask. For the
                example above, these units correspond to
                [[``My``, ``capybara``],
                 [``Your``, ``aar``, ``is``, ``awesome``]]
                and the shape of this tensor is [batch_size, TOKEN_seq_length, n_features]
        """

        device = units.device
        nf_int = units.size()[-1]
        batch_size = units.size()[0]

        # number of TOKENS in each sentence
        token_seq_lengths = torch.sum(mask, 1).to(torch.int64)
        # number of words
        n_words = torch.sum(token_seq_lengths)
        # max token seq len
        max_token_seq_len = torch.max(token_seq_lengths)

        idxs = torch.stack(torch.nonzero(mask, as_tuple=True), dim=1)
        # padding is for computing the change from one sample to another in the batch
        sample_ids_in_batch = torch.nn.functional.pad(input=idxs[:, 0], pad=[1, 0])

        a = (~torch.eq(sample_ids_in_batch[1:], sample_ids_in_batch[:-1])).to(torch.int64)

        # transforming sample start masks to the sample starts themselves
        q = a * torch.arange(n_words, device=device).to(torch.int64)
        count_to_substract = torch.nn.functional.pad(torch.masked_select(q, q.to(torch.bool)), [1, 0])

        new_word_indices = torch.arange(n_words, device=device).to(torch.int64) - count_to_substract[torch.cumsum(a, 0)]

        n_total_word_elements = max_token_seq_len * torch.ones_like(token_seq_lengths, device=device).sum()
        word_indices_flat = (idxs[:, 0] * max_token_seq_len + new_word_indices).to(torch.int64)
        x_mask = torch.zeros(n_total_word_elements, dtype=torch.bool, device=device)
        x_mask[word_indices_flat] = torch.ones_like(word_indices_flat, device=device, dtype=torch.bool)
        # to get absolute indices we add max_token_seq_len:
        # idxs[:, 0] * max_token_seq_len -> [0, 0, 0, 1, 1, 2] * 2 = [0, 0, 0, 3, 3, 6]
        # word_indices_flat -> [0, 0, 0, 3, 3, 6] + [0, 1, 2, 0, 1, 0] = [0, 1, 2, 3, 4, 6]
        # total number of word slots in the batch (including paddings):
        # batch_size * max_token_seq_len -> 3 * 3 = 9
        # x_mask -> [1, 1, 1, 1, 1, 0, 1, 0, 0]
        nonword_indices_flat = (~x_mask).nonzero().squeeze(-1)

        # get a sequence of units corresponding to the start subtokens of the words
        # size: [n_words, n_features]
        elements = units[mask.bool()]

        # prepare zeros for paddings
        # size: [batch_size * TOKEN_seq_length - n_words, n_features]
        paddings = torch.zeros_like(nonword_indices_flat, dtype=elements.dtype).unsqueeze(-1).repeat(1, nf_int).to(device)

        # tensor_flat -> [x, x, x, x, x, 0, x, 0, 0]
        tensor_flat_unordered = torch.cat([elements, paddings])
        _, order_idx = torch.sort(torch.cat([word_indices_flat, nonword_indices_flat]))
        tensor_flat = tensor_flat_unordered[order_idx]

        tensor = torch.reshape(tensor_flat, (-1, max_token_seq_len, nf_int))
        # tensor -> [[x, x, x],
        #            [x, x, 0],
        #            [x, 0, 0]]

        return tensor

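# A minimal sanity-check sketch for TokenFromSubtoken (illustrative only; the
# shapes and values below are assumptions based on the docstring example, not
# part of the original file):
#
#   units = torch.randn(2, 7, 4)                    # [batch, subtoken_len, n_features]
#   mask = torch.tensor([[0, 1, 1, 0, 0, 0, 0],
#                        [0, 1, 1, 0, 1, 1, 0]])
#   words = TokenFromSubtoken()(units, mask)
#   # words.shape == torch.Size([2, 4, 4]); the first sentence has only two
#   # words, so its remaining rows are zero-padded.
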
"""### Conditional Random Field
- Code adapted from the [torchcrf library](https://pytorch-crf.readthedocs.io/en/stable/)
- We override the Viterbi decoder in order to make it compatible with our code
"""

class CRF(SUPERCRF):

    # override the Viterbi decoder in order to make it compatible with our code
    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history = []

        # score is a tensor of size (batch_size, num_tags) where, for every batch,
        # the value at column j stores the score of the best tag sequence so far
        # that ends with tag j
        # history saves where the best tag candidates transitioned from; this is used
        # when we trace back the best tag sequence

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            # Broadcast Viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where,
            # for each sample, the entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tags
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)
            history.append(indices)

        history = torch.stack(history, dim=0)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Now, compute the best path for each sample

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        best_tags_list = []

        for idx in range(batch_size):
            # Find the tag which maximizes the score at the last timestep; this is our best tag
            # for the last timestep
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag]

            # We trace back where the best last tag comes from, append that to our best tag
            # sequence, and trace it back again, and so on
            for i, hist in enumerate(torch.flip(history[:seq_ends[idx]], dims=(0,))):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag)

            best_tags = torch.stack(best_tags, dim=0)

            # Reverse the order because we start from the last timestep
            best_tags_list.append(torch.flip(best_tags, dims=(0,)))

        best_tags_list = nn.utils.rnn.pad_sequence(best_tags_list, batch_first=True, padding_value=0)

        return best_tags_list

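# Hedged usage sketch for the overridden decoder (the shapes below are
# assumptions): unlike stock torchcrf, decode() here returns a zero-padded
# LongTensor of best tag indices instead of a list of lists.
#
#   crf = CRF(num_tags=3, batch_first=True)
#   emissions = torch.randn(2, 5, 3)                 # (batch, seq_len, num_tags)
#   crf_mask = torch.ones(2, 5, dtype=torch.bool)
#   best_tags = crf.decode(emissions, crf_mask)      # tensor of shape (2, 5)
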
"""### CRFLayer
- Forward: compute output logits based on the backbone network's embeddings
- Decode: decode based on CRF weights
"""

class CRFLayer(nn.Module):
    def __init__(self, embedding_size, n_labels):

        super(CRFLayer, self).__init__()
        self.dropout = nn.Dropout(0.1)
        self.output_dense = nn.Linear(embedding_size, n_labels)
        self.crf = CRF(n_labels, batch_first=True)
        self.token_from_subtoken = TokenFromSubtoken()

    # Forward: compute token-level output logits based on the backbone network's embeddings
    def forward(self, embedding, mask):
        logits = self.output_dense(self.dropout(embedding))
        logits = self.token_from_subtoken(logits, mask)
        pad_mask = self.token_from_subtoken(mask.unsqueeze(-1), mask).squeeze(-1).bool()
        return logits, pad_mask

    # Decode: decode based on CRF weights
    def decode(self, logits, pad_mask):
        return self.crf.decode(logits, pad_mask)

    # Evaluation loss: negative log-likelihood of the CRF (summed over the batch)
    def eval_loss(self, logits, targets, pad_mask):
        mean_log_likelihood = self.crf(logits, targets, pad_mask, reduction='sum').mean()
        return -mean_log_likelihood

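# Hedged sketch of the CRFLayer data flow (the sizes below are assumptions):
# `forward` turns subtoken-level embeddings into token-level emission logits,
# and `decode` runs CRF Viterbi decoding over them.
#
#   layer = CRFLayer(embedding_size=768, n_labels=7)
#   embedding = torch.randn(1, 6, 768)               # subtoken-level encoder output
#   word_mask = torch.tensor([[0, 1, 1, 0, 1, 0]])   # 1 marks each word's first subtoken
#   logits, pad_mask = layer(embedding, word_mask)   # logits: (1, 3, 7)
#   tags = layer.decode(logits, pad_mask)            # per-token label indices
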
"""### NERModel
- XLM-RoBERTa model with a CRF layer
"""

class NERModel(nn.Module):

    def __init__(self, n_labels:int, roberta_path:str):
        super(NERModel, self).__init__()
        self.roberta = XLMRobertaModel.from_pretrained(roberta_path)
        self.crf = CRFLayer(self.roberta.config.hidden_size, n_labels)

    # Forward: pass embeddings to the CRF layer in order to evaluate logits from the subword sequence
    def forward(self,
                input_ids:torch.Tensor,
                attention_mask:torch.Tensor,
                token_type_ids:torch.Tensor,
                mask:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

        embedding = self.roberta(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids)[0]
        logits, pad_mask = self.crf(embedding, mask)
        return logits, pad_mask

    # Disable gradients and predict with the model
    @torch.no_grad()
    def predict(self, inputs:Tuple[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        input_ids, attention_mask, token_type_ids, mask = inputs
        logits, pad_mask = self(input_ids, attention_mask, token_type_ids, mask)
        decoded = self.crf.decode(logits, pad_mask)
        return decoded, pad_mask

    # Decode: pass to the CRF decoder and decode based on CRF weights
    def decode(self, logits, pad_mask):
        """Decode logits using CRF weights"""
        return self.crf.decode(logits, pad_mask)

    # Evaluation loss: delegate to the CRF layer's eval_loss (negative log-likelihood)
    def eval_loss(self, logits, targets, pad_mask):
        return self.crf.eval_loss(logits, targets, pad_mask)

    # Determine the number of encoder layers to be fine-tuned (i.e., not frozen)
    def freeze_roberta(self, n_freeze:int=6):
        for param in self.roberta.parameters():
            param.requires_grad = False

        for param in self.roberta.encoder.layer[n_freeze:].parameters():
            param.requires_grad = True

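# Hedged construction sketch (the label count is an assumption):
#
#   model = NERModel(n_labels=7, roberta_path="xlm-roberta-base")
#   model.freeze_roberta(n_freeze=6)   # embeddings and first 6 encoder layers stay frozen
#   # At inference time, `predict` takes the tuple produced by NERTokenizer and
#   # returns (decoded_tag_indices, pad_mask).
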
"""### NERTokenizer
- NLTK tokenizer along with the XLMRobertaTokenizerFast tokenizer
- Code adapted from the following [file](https://github.com/ugurcanozalp/multilingual-ner/blob/main/multiner/utils/custom_tokenizer.py)
"""

class NERTokenizer(object):

    MAX_LEN = 512
    BATCH_LENGTH_LIMT = 380  # Max number of RoBERTa subword tokens in one sentence.

    # Modified version of http://stackoverflow.com/questions/36353125/nltk-regular-expression-tokenizer
    PATTERN = r'''(?x)               # set flag to allow verbose regexps
            (?:[A-Z]\.)+             # abbreviations, e.g. U.S.A.
            | (?:\d+\.)              # numbers
            | \w+(?:[-.]\w+)*        # words with optional internal hyphens
            | \$?\d+(?:.\d+)?%?      # currency and percentages, e.g. $12.40, 82%
            | \.\.\.                 # ellipsis, and special chars below, includes ], [
            | [-\]\[.؟،؛;"'?,():_`“”/°º‘’″…#$%()*+<>=@\\^_{}|~❑&§\!]
            | \u200c
            '''

    def __init__(self, base_model:str, to_device:str='cpu'):
        super(NERTokenizer, self).__init__()
        self.roberta_tokenizer = XLMRobertaTokenizerFast.from_pretrained(base_model, do_lower_case=False, padding=True, truncation=True)
        self.to_device = to_device

        self.word_tokenizer = RegexpTokenizer(self.PATTERN)
        self.sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

    # tokenize a batch of token lists
    def tokenize_batch(self, inputs, pad_to=None) -> Tuple[torch.Tensor, ...]:
        batch = [inputs] if isinstance(inputs[0], str) else inputs

        input_ids, attention_mask, token_type_ids, mask = [], [], [], []
        for tokens in batch:
            input_ids_tmp, attention_mask_tmp, token_type_ids_tmp, mask_tmp = self._tokenize_words(tokens)
            input_ids.append(input_ids_tmp)
            attention_mask.append(attention_mask_tmp)
            token_type_ids.append(token_type_ids_tmp)
            mask.append(mask_tmp)

        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.roberta_tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=0)
        mask = pad_sequence(mask, batch_first=True, padding_value=0)
        # truncate to MAX_LEN
        if input_ids.shape[-1] > self.MAX_LEN:
            input_ids = input_ids[:, :self.MAX_LEN]
            attention_mask = attention_mask[:, :self.MAX_LEN]
            token_type_ids = token_type_ids[:, :self.MAX_LEN]
            mask = mask[:, :self.MAX_LEN]

        # extend padding
        elif pad_to is not None and pad_to > input_ids.shape[1]:
            bs = input_ids.shape[0]
            padlen = pad_to - input_ids.shape[1]

            input_ids_append = torch.tensor([self.roberta_tokenizer.pad_token_id], dtype=torch.long).repeat([bs, padlen]).to(self.to_device)
            input_ids = torch.cat([input_ids, input_ids_append], dim=-1)

            attention_mask_append = torch.tensor([0], dtype=torch.long).repeat([bs, padlen]).to(self.to_device)
            attention_mask = torch.cat([attention_mask, attention_mask_append], dim=-1)

            token_type_ids_append = torch.tensor([0], dtype=torch.long).repeat([bs, padlen]).to(self.to_device)
            token_type_ids = torch.cat([token_type_ids, token_type_ids_append], dim=-1)

            mask_append = torch.tensor([0], dtype=torch.long).repeat([bs, padlen]).to(self.to_device)
            mask = torch.cat([mask, mask_append], dim=-1)

        # truncate to pad_to
        elif pad_to is not None and pad_to < input_ids.shape[1]:
            input_ids = input_ids[:, :pad_to]
            attention_mask = attention_mask[:, :pad_to]
            token_type_ids = token_type_ids[:, :pad_to]
            mask = mask[:, :pad_to]

        if isinstance(inputs[0], str):
            return input_ids[0], attention_mask[0], token_type_ids[0], mask[0]
        else:
            return input_ids, attention_mask, token_type_ids, mask

    # tokenize a list of words with the RoBERTa tokenizer
    def _tokenize_words(self, words):
        roberta_tokens = []
        mask = []
        for word in words:
            subtokens = self.roberta_tokenizer.tokenize(word)
            roberta_tokens += subtokens
            n_subtoken = len(subtokens)
            if n_subtoken >= 1:
                mask = mask + [1] + [0] * (n_subtoken - 1)

        # add special tokens [CLS] and [SEP]
        roberta_tokens = [self.roberta_tokenizer.cls_token] + roberta_tokens + [self.roberta_tokenizer.sep_token]
        mask = [0] + mask + [0]
        input_ids = torch.tensor(self.roberta_tokenizer.convert_tokens_to_ids(roberta_tokens), dtype=torch.long).to(self.to_device)
        attention_mask = torch.ones(len(mask), dtype=torch.long).to(self.to_device)
        token_type_ids = torch.zeros(len(mask), dtype=torch.long).to(self.to_device)
        mask = torch.tensor(mask, dtype=torch.long).to(self.to_device)
        return input_ids, attention_mask, token_type_ids, mask

    # sent_to_token: yield each sentence's words together with their positional spans, using NLTK
    def sent_to_token(self, raw_text):
        for offset, ending in self.sent_tokenizer.span_tokenize(raw_text):
            sub_text = raw_text[offset:ending]
            words, spans = [], []
            flush = False
            total_subtoken = 0
            for start, end in self.word_tokenizer.span_tokenize(sub_text):
                flush = True
                start += offset
                end += offset
                words.append(raw_text[start:end])
                spans.append((start, end))
                total_subtoken += len(self.roberta_tokenizer.tokenize(words[-1]))
                if (total_subtoken > self.BATCH_LENGTH_LIMT):
                    yield words[:-1], spans[:-1]
                    spans = spans[len(spans) - 1:]
                    words = words[len(words) - 1:]
                    total_subtoken = sum([len(self.roberta_tokenizer.tokenize(word)) for word in words])
                    flush = False

            if flush and len(spans) > 0:
                yield words, spans

    # Extract batches of (tokenized tensors, words, spans) from a raw text
    def prepare_row_text(self, raw_text, batch_size=16):
        words_list, spans_list = [], []
        end_batch = False
        for words, spans in self.sent_to_token(raw_text):
            end_batch = True
            words_list.append(words)
            spans_list.append(spans)
            if len(spans_list) >= batch_size:
                input_ids, attention_mask, token_type_ids, mask = self.tokenize_batch(words_list)
                yield (input_ids, attention_mask, token_type_ids, mask), words_list, spans_list
                words_list, spans_list = [], []
        if end_batch and len(words_list) > 0:
            input_ids, attention_mask, token_type_ids, mask = self.tokenize_batch(words_list)
            yield (input_ids, attention_mask, token_type_ids, mask), words_list, spans_list

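# Hedged usage sketch for NERTokenizer (the sample text is an assumption):
#
#   tokenizer = NERTokenizer(base_model="xlm-roberta-base", to_device="cpu")
#   for batch, words, spans in tokenizer.prepare_row_text("Some raw text ..."):
#       input_ids, attention_mask, token_type_ids, mask = batch
#       # `words` and `spans` hold, per sentence, the word strings and their
#       # (start, end) character offsets in the original text.
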
"""### NER
NER interface: we use this interface to infer time and date tags for a sentence.
"""

class NER(object):

    def __init__(self, model_path, tags):

        self.tags = tags
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load the pre-trained backbone
        roberta_path = "xlm-roberta-base"
        self.model = NERModel(n_labels=len(self.tags), roberta_path=roberta_path).to(self.device)
        # Load the fine-tuned weights
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict, strict=False)
        # Enable evaluation mode
        self.model.eval()
        self.tokenizer = NERTokenizer(base_model=roberta_path, to_device=self.device)

    # Predict and pre/post-process the input/output
    @torch.no_grad()
    def __call__(self, raw_text):

        outputs_flat, spans_flat, entities = [], [], []
        for batch, words, spans in self.tokenizer.prepare_row_text(raw_text):
            output, pad_mask = self.model.predict(batch)
            outputs_flat.extend(output[pad_mask.bool()].reshape(-1).tolist())
            spans_flat += sum(spans, [])

        for tag_idx, (start, end) in zip(outputs_flat, spans_flat):
            tag = self.tags[tag_idx]
            # filter out O tags
            if tag != 'O':
                entities.append({'Text': raw_text[start:end],
                                 'Tag': tag,
                                 'Start': start,
                                 'End': end})

        return entities
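
# Example usage sketch. The checkpoint path and the tag inventory below are
# assumptions for illustration; the real values come from the Hengam training
# setup and are not defined in this file.
#
#   tags = ['O', 'B-TIM', 'I-TIM', 'B-DAT', 'I-DAT']          # hypothetical IOB2 tag set
#   ner = NER(model_path='hengam_checkpoint.pt', tags=tags)   # hypothetical path
#   entities = ner('یکشنبه ساعت ۸ صبح')
#   # -> list of dicts like {'Text': ..., 'Tag': ..., 'Start': ..., 'End': ...}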