import itertools
from functools import partial
from typing import Callable, List, Tuple, Union

import torch
from seqeval.scheme import Tokens, IOB2
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizerBase
from pythainlp.tokenize import word_tokenize as pythainlp_word_tokenize

from thai2transformers.preprocess import rm_useless_spaces

newmm_word_tokenizer = partial(pythainlp_word_tokenize,
                               keep_whitespace=True,
                               engine='newmm')

SPIECE = '▁'


class TokenClassificationPipeline:

    def __init__(self,
                 model: PreTrainedModel,
                 tokenizer: PreTrainedTokenizerBase,
                 pretokenizer: Callable[[str], List[str]] = newmm_word_tokenizer,
                 lowercase: bool = False,
                 space_token: str = '<_>',
                 device: int = -1,
                 group_entities: bool = False,
                 strict: bool = False,
                 tag_delimiter: str = '-',
                 scheme: str = 'IOB',
                 use_crf: bool = False,
                 remove_spiece: bool = True):

        super().__init__()

        assert isinstance(tokenizer, PreTrainedTokenizerBase)
        # assert isinstance(model, PreTrainedModel)

        self.model = model
        self.tokenizer = tokenizer
        self.pretokenizer = pretokenizer
        self.lowercase = lowercase
        self.space_token = space_token
        self.device = 'cpu' if device == -1 or not torch.cuda.is_available() else f'cuda:{device}'
        self.group_entities = group_entities
        self.strict = strict
        self.tag_delimiter = tag_delimiter
        self.scheme = scheme
        self.id2label = self.model.config.id2label
        self.label2id = self.model.config.label2id
        self.use_crf = use_crf
        self.remove_spiece = remove_spiece

        self.model.to(self.device)

    def preprocess(self, inputs: Union[str, List[str]]) -> Union[List[str], List[List[str]]]:
        """Optionally lowercase, strip redundant spaces, pretokenize, and
        replace space characters with ``space_token``."""
        if self.lowercase:
            inputs = inputs.lower() if isinstance(inputs, str) else list(map(str.lower, inputs))

        inputs = rm_useless_spaces(inputs) if isinstance(inputs, str) else list(map(rm_useless_spaces, inputs))

        tokens = self.pretokenizer(inputs) if isinstance(inputs, str) else list(map(self.pretokenizer, inputs))

        if isinstance(inputs, str):
            tokens = [tok.replace(' ', self.space_token) for tok in tokens]
        else:
            tokens = [[tok.replace(' ', self.space_token) for tok in _tokens] for _tokens in tokens]

        return tokens

    def _inference(self, text: str):

        # Wrap the pretokenized input with BOS/EOS and subword-tokenize each
        # pretoken separately, so predictions can later be merged back to one
        # label per pretokenized word.
        tokens = [[self.tokenizer.bos_token]] + \
                 [self.tokenizer.tokenize(tok) if tok != SPIECE else [SPIECE]
                  for tok in self.preprocess(text)] + \
                 [[self.tokenizer.eos_token]]
        ids = [self.tokenizer.convert_tokens_to_ids(token) for token in tokens]
        flatten_tokens = list(itertools.chain(*tokens))
        flatten_ids = list(itertools.chain(*ids))

        input_ids = torch.LongTensor([flatten_ids]).to(self.device)

        if self.use_crf:
            out = self.model(input_ids=input_ids)
        else:
            out = self.model(input_ids=input_ids, return_dict=True)

        probs = torch.softmax(out['logits'], dim=-1)
        _, indices = probs.topk(1)
        indices_np = indices.detach().cpu().numpy().reshape(-1)

        list_of_token_label_tuple = list(zip(flatten_tokens,
                                             [self.id2label[idx] for idx in indices_np]))

        merged_preds = self._merged_pred(preds=list_of_token_label_tuple, ids=ids)
        if self.remove_spiece:
            merged_preds = [(word.replace(SPIECE, ''), tag) for word, tag in merged_preds]

        # Remove BOS and EOS tokens.
        merged_preds_removed_bos_eos = merged_preds[1:-1]

        # Convert to a list of dicts.
        merged_preds_return_dict = [
            {'word': word if word != self.space_token else ' ',
             'entity': tag,
             'index': idx}
            for idx, (word, tag) in enumerate(merged_preds_removed_bos_eos)
        ]

        if (not self.group_entities or self.scheme is None) and self.strict:
            return merged_preds_return_dict
        elif not self.group_entities and not self.strict:
            tags = [item['entity'] for item in merged_preds_return_dict]
            processed_tags = self._fix_incorrect_tags(tags)
            for i, item in enumerate(merged_preds_return_dict):
                merged_preds_return_dict[i]['entity'] = processed_tags[i]
            return merged_preds_return_dict
        elif self.group_entities:
            return self._group_entities(merged_preds_removed_bos_eos)
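    # A hedged sketch of the non-grouped output shape: with
    # group_entities=False, _inference returns one dict per pretokenized
    # word, e.g.
    #   [{'word': 'ไป', 'entity': 'O', 'index': 0},
    #    {'word': 'เชียงใหม่', 'entity': 'B-LOC', 'index': 1}]
    # The tags here are illustrative; the actual label set comes from
    # model.config.id2label.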
    def __call__(self, inputs: Union[str, List[str]]):
        """Run token classification on a string or a list of strings."""
        if isinstance(inputs, str):
            return self._inference(inputs)
        if isinstance(inputs, list):
            return [self._inference(text) for text in inputs]

    def _merged_pred(self, preds: List[Tuple[str, str]], ids: List[List[int]]):
        # Map each flattened subtoken index back to the index of the
        # pretokenized word it came from.
        token_mapping = []
        for i in range(len(ids)):
            for _ in range(len(ids[i])):
                token_mapping.append(i)

        # Group consecutive subtokens that belong to the same pretokenized word.
        grouped_subtokens = []
        _subtoken = []
        prev_idx = 0
        for i, (subtoken, label) in enumerate(preds):
            current_idx = token_mapping[i]
            if prev_idx != current_idx:
                grouped_subtokens.append(_subtoken)
                _subtoken = [(subtoken, label)]
                if i == len(preds) - 1:
                    grouped_subtokens.append(_subtoken)
            elif i == len(preds) - 1:
                _subtoken += [(subtoken, label)]
                grouped_subtokens.append(_subtoken)
            else:
                _subtoken += [(subtoken, label)]
            prev_idx = current_idx

        # Merge each subtoken group back into a single string and take the
        # first subtoken's prediction as the word-level label.
        merged_subtokens = []
        for subtoken_group in grouped_subtokens:
            first_token_pred = subtoken_group[0][1]
            merged_subtoken = ''.join(subtoken for subtoken, _ in subtoken_group)
            merged_subtokens.append((merged_subtoken, first_token_pred))
        return merged_subtokens

    def _fix_incorrect_tags(self, tags: List[str]) -> List[str]:
        I_PREFIX = f'I{self.tag_delimiter}'
        E_PREFIX = f'E{self.tag_delimiter}'
        B_PREFIX = f'B{self.tag_delimiter}'
        O_TAG = 'O'

        previous_tag_ne = None
        for i, current_tag in enumerate(tags):
            current_tag_ne = current_tag.split(self.tag_delimiter)[-1] if current_tag != O_TAG else O_TAG

            if i == 0 and (current_tag.startswith(I_PREFIX) or
                           current_tag.startswith(E_PREFIX)):
                # A NE tag with an I- or E- prefix occurring at the beginning
                # of the sentence, e.g. (I-LOC, I-LOC), (E-LOC, B-PER),
                # (I-LOC, O, O): change the prefix of the current tag to B-.
                tags[i] = B_PREFIX + tags[i][2:]
            elif i >= 1 and tags[i - 1] == O_TAG and (current_tag.startswith(I_PREFIX) or
                                                      current_tag.startswith(E_PREFIX)):
                # A NE tag with an I- or E- prefix occurring after an O tag,
                # e.g. (O, I-LOC, I-LOC), (O, E-LOC, B-PER), (O, I-LOC, O, O):
                # change the prefix of the current tag to B-.
                tags[i] = B_PREFIX + tags[i][2:]
            elif i >= 1 and (tags[i - 1].startswith(I_PREFIX) or
                             tags[i - 1].startswith(E_PREFIX) or
                             tags[i - 1].startswith(B_PREFIX)) and \
                    (current_tag.startswith(I_PREFIX) or current_tag.startswith(E_PREFIX)) and \
                    previous_tag_ne != current_tag_ne:
                # A NE tag with an I- or E- prefix occurring after a NE tag of
                # a different entity type, e.g. (B-LOC, I-PER),
                # (B-LOC, E-LOC, E-PER), (B-LOC, I-LOC, I-PER): change the
                # prefix of the current tag to B-.
                tags[i] = B_PREFIX + tags[i][2:]
            elif i == len(tags) - 1 and tags[i - 1] == O_TAG and (current_tag.startswith(I_PREFIX) or
                                                                  current_tag.startswith(E_PREFIX)):
                # A NE tag with an I- or E- prefix occurring at the end of the
                # sentence, e.g. (O, O, I-LOC), (O, O, E-LOC): change the
                # prefix of the current tag to B-.
                tags[i] = B_PREFIX + tags[i][2:]

            previous_tag_ne = current_tag_ne

        return tags
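    # A worked example of the repair rules above (hedged illustration,
    # assuming the default tag_delimiter='-'):
    #   _fix_incorrect_tags(['I-LOC', 'I-LOC', 'O', 'E-PER'])
    #   -> ['B-LOC', 'I-LOC', 'O', 'B-PER']
    # The leading I-LOC gains a B- prefix because it opens the sequence, and
    # E-PER gains one because it follows an O tag.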
    def _group_entities(self, ner_tags: List[Tuple[str, str]]) -> List[Tuple[str, str]]:

        if self.scheme not in ['IOB', 'IOBES', 'IOBE']:
            raise ValueError(f'Unsupported tagging scheme: {self.scheme}')

        tokens, tags = zip(*ner_tags)
        tokens, tags = list(tokens), list(tags)

        if self.scheme == 'IOBE':
            # Replace the E- prefix with the I- prefix.
            tags = [tag.replace(f'E{self.tag_delimiter}', f'I{self.tag_delimiter}') for tag in tags]
        if self.scheme == 'IOBES':
            # Replace the E- prefix with I- and the S- prefix with B-.
            tags = [tag.replace(f'E{self.tag_delimiter}', f'I{self.tag_delimiter}') for tag in tags]
            tags = [tag.replace(f'S{self.tag_delimiter}', f'B{self.tag_delimiter}') for tag in tags]

        if not self.strict:
            tags = self._fix_incorrect_tags(tags)

        # seqeval's Tokens takes the tag sequence and extracts entity spans.
        ent = Tokens(tokens=tags, scheme=IOB2, suffix=False, delimiter=self.tag_delimiter)
        ne_position_mappings = ent.entities

        # Restore spaces, normalize the decomposed sara am (nikhahit + sara
        # aa) to the composed character ำ, then compute the character offsets
        # of each token.
        tokens = [tok.replace(self.space_token, ' ').replace('ํา', 'ำ') for tok in tokens]
        token_positions = []
        curr_len = 0
        for token in tokens:
            token_positions.append((curr_len, curr_len + len(token)))
            curr_len += len(token)

        # Token-index and character-index spans of each detected entity.
        begin_end_pos = []
        begin_end_char_pos = []
        for ne_position_mapping in ne_position_mappings:
            begin_end_pos.append((ne_position_mapping.start, ne_position_mapping.end))
            begin_end_char_pos.append((token_positions[ne_position_mapping.start][0],
                                       token_positions[ne_position_mapping.end - 1][1]))

        # Insert O spans for the gaps before and between detected entities so
        # that the grouped output covers the whole input.
        j = 0
        for i, pos_tuple in enumerate(begin_end_pos):
            if pos_tuple[0] > 0 and i == 0:
                ne_position_mappings.insert(0, (None, 'O', 0, pos_tuple[0]))
                j += 1
            if i > 0 and len(begin_end_pos) > 1 and begin_end_pos[i - 1][1] != begin_end_pos[i][0]:
                ne_position_mappings.insert(j, (None, 'O', begin_end_pos[i - 1][1], begin_end_pos[i][0]))
                j += 1
            j += 1

        groups = []
        k = 0
        for ne_position_mapping in ne_position_mappings:
            if not isinstance(ne_position_mapping, tuple):
                ne_position_mapping = ne_position_mapping.to_tuple()
            ne = ne_position_mapping[1]
            text = ''.join(tokens[ne_position]
                           for ne_position in range(ne_position_mapping[2], ne_position_mapping[3]))
            if ne.lower() != 'o':
                groups.append({
                    'entity_group': ne,
                    'word': text,
                    'begin_char_index': begin_end_char_pos[k][0],
                })
                k += 1
            else:
                groups.append({
                    'entity_group': ne,
                    'word': text,
                })
        return groups
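

if __name__ == '__main__':
    # Minimal usage sketch. 'your-org/your-thai-ner-model' is a placeholder
    # checkpoint name, not a real published model; any token-classification
    # checkpoint whose id2label matches the chosen scheme should work.
    from transformers import AutoModelForTokenClassification, AutoTokenizer

    model = AutoModelForTokenClassification.from_pretrained('your-org/your-thai-ner-model')
    tokenizer = AutoTokenizer.from_pretrained('your-org/your-thai-ner-model')

    pipeline = TokenClassificationPipeline(
        model=model,
        tokenizer=tokenizer,
        scheme='IOBES',       # must match the scheme the model was trained with
        group_entities=True,  # merge B-/I- spans into entity groups
    )
    print(pipeline('วันนี้ฉันเดินทางไปจังหวัดเชียงใหม่'))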