"""Entity and relation inference with a Transformer token-classification model."""
import json
import logging
from typing import List, Any
import copy

import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer

from util.process_data import Sample, Entity, EntityType, EntityTypeSet, SampleList, Token, Relation
from util.configuration import InferenceConfiguration

# Allowed (head entity label -> tail entity labels) combinations for relations.
valid_relations = {  # head : [tail, ...]
    "StatedKeyFigure": ["StatedKeyFigure", "Condition", "StatedExpression", "DeclarativeExpression"],
    "DeclarativeKeyFigure": ["DeclarativeKeyFigure", "Condition", "StatedExpression", "DeclarativeExpression"],
    "StatedExpression": ["Unit", "Factor", "Range", "Condition"],
    "DeclarativeExpression": ["DeclarativeExpression", "Unit", "Factor", "Range", "Condition"],
    "Condition": ["Condition", "StatedExpression", "DeclarativeExpression"],
    "Range": ["Range"],
}

# Alignment entry used for ordinary word positions in the relation input.
# It must never parse as an integer: alignment entries that DO parse as an
# integer are interpreted as entity ids by ``generate_response_relations``.
_WORD_ALIGNMENT_PLACEHOLDER = ""


class TokenClassificationDataset(Dataset):
    """ Pytorch Dataset wrapping tokenizer encodings and per-token labels. """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # One tensor per encoding field plus the matching label tensor.
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


class TransformersInference():
    """Predict entities (and optionally relations) for a ``SampleList``.

    The entity type set is read from ``classification.json`` in the current
    working directory; tokenizer and model are loaded from the paths given in
    the configuration.
    """

    def __init__(self, config: InferenceConfiguration):
        """Load the entity type set, tokenizer and model and register special tokens.

        Args:
            config: inference configuration; the attributes read here are
                ``transformer_model``, ``model_path_keyfigure``,
                ``merge_entities``, ``split_len`` and ``extract_relations``.
        """
        super().__init__()
        self.__logger = logging.getLogger(self.__class__.__name__)
        self.__logger.info(f"Load Configuration: {config.dict()}")

        with open("classification.json", mode='r', encoding="utf-8") as f:
            self.__entity_type_set = EntityTypeSet.parse_obj(json.load(f))
        all_types = self.__entity_type_set.all_types()
        self.__entity_type_label_to_id_mapping = {x.label: x.idx for x in all_types}
        self.__entity_type_id_to_label_mapping = {x.idx: x.label for x in all_types}

        self.__logger.info("Load Model: " + config.model_path_keyfigure)
        self.__tokenizer = AutoTokenizer.from_pretrained(
            config.transformer_model, padding="max_length", max_length=512, truncation=True)
        self.__model = AutoModelForTokenClassification.from_pretrained(
            config.model_path_keyfigure, num_labels=(len(self.__entity_type_set)))
        self.__trainer = Trainer(model=self.__model)
        self.__merge_entities = config.merge_entities
        self.__split_len = config.split_len
        self.__extract_relations = config.extract_relations

        # add special tokens: task triggers, subject/object masks and one
        # open/close marker per entity type of every group
        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        lst_special_tokens = ["[REL]", "[SUB]", "[/SUB]", "[OBJ]", "[/OBJ]"]
        for grp_idx, grp in enumerate(self.__entity_type_set.groups):
            lst_special_tokens.append(f"[GRP-{grp_idx:02d}]")
            lst_special_tokens.extend([f"[ENT-{ent:02d}]" for ent in grp if ent != id_of_non_entity])
            lst_special_tokens.extend([f"[/ENT-{ent:02d}]" for ent in grp if ent != id_of_non_entity])

        lst_special_tokens = sorted(set(lst_special_tokens))
        special_tokens_dict = {'additional_special_tokens': lst_special_tokens}
        num_added_toks = self.__tokenizer.add_special_tokens(special_tokens_dict)
        self.__logger.info(f"Added {num_added_toks} new special tokens. All special tokens: '{self.__tokenizer.all_special_tokens}'")

        self.__logger.info("Initialization completed.")

    def run_inference(self, sample_list: SampleList):
        """Predict entities for every entity group and, if enabled, relations.

        Results are written onto the samples in ``sample_list`` (``entities``,
        ``tags`` and, when relation extraction is enabled, ``relations``).
        """
        group_predictions = []
        group_entity_ids = []
        self.__logger.info("Predict Entities ...")
        for grp_idx, _grp in enumerate(self.__entity_type_set.groups):
            token_lists = [[x.text for x in sample.tokens] for sample in sample_list.samples]
            predictions = self.__get_predictions(token_lists, f"[GRP-{grp_idx:02d}]")
            group_entity_ids_ = [
                self.generate_response_entities(sample, prediction_per_tokens, grp_idx)
                for sample, prediction_per_tokens in zip(sample_list.samples, predictions)
            ]
            group_predictions.append(predictions)
            group_entity_ids.append(group_entity_ids_)

        if self.__extract_relations:
            self.__logger.info("Predict Relations ...")
            self.__do_extract_relations(sample_list, group_predictions, group_entity_ids)

    def __do_extract_relations(self, sample_list, group_predictions, group_entity_ids):
        """Build one marked-up token sequence per head entity and predict relations.

        Args:
            sample_list: samples whose ``entities`` have already been predicted.
            group_predictions: indexed ``[group][sample][token]`` -> predicted type id.
            group_entity_ids: indexed ``[group][sample][token]`` -> entity id or -1.
        """
        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        for sample_idx, sample in enumerate(sample_list.samples):
            masked_tokens = []
            masked_tokens_align = []
            # create SUB-Mask for every entity that can be a head
            head_entities = [entity_ for entity_ in sample.entities
                             if entity_.ent_type.label in valid_relations]
            for entity_ in head_entities:
                ent_masked_tokens = []
                ent_masked_tokens_align = []
                last_preds = [id_of_non_entity for _ in group_predictions]
                last_ent_ids = [-1 for _ in group_entity_ids]
                for token_idx, token in enumerate(sample.tokens):
                    cur_preds = [group[sample_idx][token_idx] for group in group_predictions]
                    cur_ent_ids = [ent_ids[sample_idx][token_idx] for ent_ids in group_entity_ids]

                    # close every span that ended right before this token
                    for last_pred, last_ent_id, pred in zip(last_preds, last_ent_ids, cur_preds):
                        if last_pred != pred and last_pred != id_of_non_entity:
                            mask = "[/SUB]" if last_ent_id == entity_.id else "[/OBJ]"
                            ent_masked_tokens.extend([f"[/ENT-{last_pred:02d}]", mask])
                            ent_masked_tokens_align.extend([str(last_ent_id), str(last_ent_id)])

                    # open every span that starts at this token
                    for last_pred, pred, ent_id in zip(last_preds, cur_preds, cur_ent_ids):
                        if last_pred != pred and pred != id_of_non_entity:
                            mask = "[SUB]" if ent_id == entity_.id else "[OBJ]"
                            ent_masked_tokens.extend([mask, f"[ENT-{pred:02d}]"])
                            ent_masked_tokens_align.extend([str(ent_id), str(ent_id)])

                    ent_masked_tokens.append(token.text)
                    # Do not store the word text here: a numeric word such as
                    # "3" would later be parsed as entity id 3.
                    ent_masked_tokens_align.append(_WORD_ALIGNMENT_PLACEHOLDER)

                    last_preds = cur_preds
                    last_ent_ids = cur_ent_ids

                # close spans that are still open at the end of the sample
                for last_pred, last_ent_id in zip(last_preds, last_ent_ids):
                    if last_pred != id_of_non_entity:
                        mask = "[/SUB]" if last_ent_id == entity_.id else "[/OBJ]"
                        ent_masked_tokens.extend([f"[/ENT-{last_pred:02d}]", mask])
                        ent_masked_tokens_align.extend([str(last_ent_id), str(last_ent_id)])

                masked_tokens.append(ent_masked_tokens)
                masked_tokens_align.append(ent_masked_tokens_align)

            rel_predictions = self.__get_predictions(masked_tokens, "[REL]")
            self.generate_response_relations(sample, head_entities, masked_tokens_align, rel_predictions)

    def generate_response_entities(self, sample: Sample, predictions_per_tokens: List[int], grp_idx: int):
        """Create entities from per-token predictions and attach them to ``sample``.

        Entity ids start at ``grp_idx * 1000 + 1``. ``sample.entities`` is
        extended and ``sample.tags`` is updated in place.

        Returns:
            One entry per (token, prediction) pair: the id of the entity the
            token belongs to, or -1 for non-entity tokens.
        """
        entities = []
        entity_ids = []
        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        idx = grp_idx * 1000
        for token, prediction in zip(sample.tokens, predictions_per_tokens):
            if id_of_non_entity == prediction:
                entity_ids.append(-1)
                continue
            idx += 1
            entities.append(self.__build_entity(idx, prediction, token))
            entity_ids.append(idx)

        if self.__merge_entities:
            entities = self.__do_merge_entities(copy.deepcopy(entities))
            # Consecutive tokens with the same prediction share the id of the
            # first token of the run. Bounded by len(entity_ids) so a longer
            # prediction list cannot index past the collected ids.
            for pos in range(1, len(entity_ids)):
                if predictions_per_tokens[pos - 1] == predictions_per_tokens[pos]:
                    entity_ids[pos] = entity_ids[pos - 1]

        sample.entities += entities
        tags = sample.tags if len(sample.tags) > 0 else [id_of_non_entity] * len(sample.tokens)
        for tag_id, tok in enumerate(sample.tokens):
            for ent in entities:
                if ent.start <= tok.start < ent.end:
                    tags[tag_id] = ent.ent_type.idx
        self.__logger.info(tags)
        sample.tags = tags
        return entity_ids

    def generate_response_relations(self, sample: Sample, head_entities: List[Entity],
                                    masked_tokens_align: List[List[str]], rel_predictions: List[List[int]]):
        """Create relations from relation predictions and store them on ``sample``.

        Args:
            sample: sample whose ``relations`` attribute is overwritten.
            head_entities: one head entity per predicted sequence.
            masked_tokens_align: per sequence, one string per input position;
                positions whose string parses as an integer refer to the
                entity with that id, all other positions are ignored.
            rel_predictions: per sequence, the predicted type id per position.
        """
        relations = []
        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        idx = 0
        for entity_, align_per_ent, prediction_per_ent in zip(head_entities, masked_tokens_align, rel_predictions):
            for token, prediction in zip(align_per_ent, prediction_per_ent):
                if id_of_non_entity == prediction:
                    continue
                try:
                    tail = int(token)
                except (ValueError, TypeError):
                    # not an entity marker position
                    continue
                if not self.__validate_relation(sample.entities, entity_.id, tail, prediction):
                    continue
                idx += 1
                relations.append(self.__build_relation(idx, entity_.id, tail, prediction))
        sample.relations = relations

    def __validate_relation(self, entities: List[Entity], head: int, tail: int, prediction: int):
        """Return True if head and tail are distinct known entities whose
        labels form a combination listed in ``valid_relations``."""
        if head == tail:
            return False
        head_labels = [ent.ent_type.label for ent in entities if ent.id == head]
        tail_labels = [ent.ent_type.label for ent in entities if ent.id == tail]
        if not head_labels or not tail_labels:
            return False
        return tail_labels[0] in valid_relations[head_labels[0]]

    def __build_entity(self, idx: int, prediction: int, token: Token) -> Entity:
        """Create a single-token entity of the predicted type."""
        return Entity(
            id=idx,
            text=token.text,
            start=token.start,
            end=token.end,
            ent_type=EntityType(
                idx=prediction,
                label=self.__entity_type_id_to_label_mapping[prediction]
            )
        )

    def __build_relation(self, idx: int, head: int, tail: int, prediction: int) -> Relation:
        """Create a relation between two entity ids with the predicted type."""
        return Relation(
            id=idx,
            head=head,
            tail=tail,
            rel_type=EntityType(
                idx=prediction,
                label=self.__entity_type_id_to_label_mapping[prediction]
            )
        )

    def __do_merge_entities(self, input_ents_):
        """Merge neighbouring entities of the same type.

        Two consecutive entities are merged when they have the same type id
        and the second starts at most one position after the first ends. The
        entities passed in are modified in place.
        """
        out_ents = []
        current_ent = None
        for ent in input_ents_:
            if current_ent is None:
                current_ent = ent
                continue
            idx_diff = ent.start - current_ent.end
            if ent.ent_type.idx == current_ent.ent_type.idx and idx_diff <= 1:
                current_ent.end = ent.end
                current_ent.text += (" " if idx_diff == 1 else "") + ent.text
            else:
                out_ents.append(current_ent)
                current_ent = ent
        if current_ent is not None:
            out_ents.append(current_ent)
        return out_ents

    def __predict_batch(self, token_lists_trigger: List[List[str]]) -> List[List[int]]:
        """Tokenize, predict and align one batch of trigger-prefixed word lists.

        Returns the arg-max label id per word, with the leading trigger
        position removed.
        """
        val_encodings = self.__tokenizer(token_lists_trigger, is_split_into_words=True,
                                         padding='max_length', truncation=True)
        # Dummy labels: the dataset requires labels, their values are not used.
        val_labels = [[0] * len(val_encodings.word_ids(batch_index=i))
                      for i in range(len(token_lists_trigger))]
        val_dataset = TokenClassificationDataset(val_encodings, val_labels)
        predictions_raw, _, _ = self.__trainer.predict(val_dataset)
        predictions_align = self.__align_predictions(predictions_raw, val_encodings)
        return [[token.index(max(token)) for token in sample][1:] for sample in predictions_align]

    def __get_predictions(self, token_lists: List[List[str]], trigger: str) -> List[List[int]]:
        """ Get predictions of Transformer Sequence Labeling model

        Args:
            token_lists: one list of word strings per sequence.
            trigger: special token prepended to every sequence.

        Returns:
            One list of predicted label ids per input sequence; an empty list
            if ``token_lists`` is empty.
        """
        if not token_lists:
            # nothing to predict; avoids running tokenizer/trainer on no data
            return []

        if self.__split_len <= 0:
            return self.__predict_batch([[trigger] + sample for sample in token_lists])

        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        predictions = []
        for sample_token_lists in self.__do_split_sentences(token_lists, self.__split_len):
            sample_token_lists_trigger = [[trigger] + sample for sample in sample_token_lists]
            predictions_sample = self.__predict_batch(sample_token_lists_trigger)
            predictions_part = []
            for tok, pred in zip(sample_token_lists_trigger, predictions_sample):
                if trigger == "[REL]" and "[SUB]" not in tok:
                    # a chunk without the subject marker cannot carry relations
                    predictions_part += [id_of_non_entity] * len(pred)
                else:
                    predictions_part += pred
            predictions.append(predictions_part)
        return predictions

    def __do_split_sentences(self, tokens_: List[List[str]], split_len_ = 200) -> List[List[List[str]]]:
        """Split token lists longer than ``split_len_`` into shorter chunks.

        A chunk's first/last token is replaced by "." if it contains a
        newline. The input lists are not modified.

        Returns:
            For every input list, a list of chunks (a single chunk holding the
            original list when no split is needed).
        """
        res_tokens = []
        for tok_lst in tokens_:
            length = len(tok_lst)
            if length <= split_len_:
                res_tokens.append([tok_lst])
                continue

            num_lists = length // split_len_ + (1 if (length % split_len_) > 0 else 0)
            new_length = int(length / num_lists) + 1
            self.__logger.info(f"Splitting a list of {length} elements into {num_lists} lists of length {new_length}..")
            res_tokens_sample = []
            start_idx = 0
            for _ in range(num_lists):
                if start_idx >= length:
                    # all tokens already consumed by the previous chunks
                    break
                end_idx = min(start_idx + new_length, length)
                chunk = tok_lst[start_idx:end_idx]
                if "\n" in chunk[0]:
                    chunk[0] = "."
                if "\n" in chunk[-1]:
                    chunk[-1] = "."
                res_tokens_sample.append(chunk)
                start_idx = end_idx
            res_tokens.append(res_tokens_sample)
        return res_tokens

    def __align_predictions(self, predictions, tokenized_inputs, sum_all_tokens=False) -> List[List[List[float]]]:
        """ Align predicted labels from Transformer Tokenizer

        Reduces sub-token scores to one score vector per word using the
        tokenizer's word ids.

        Args:
            predictions: per sequence, one score vector per sub-token.
            tokenized_inputs: tokenizer output providing ``word_ids``.
            sum_all_tokens: if True, sum the scores of all sub-tokens of a
                word; otherwise keep only the first sub-token's scores.
        """
        confidence = []
        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        for i, tagset in enumerate(predictions):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            previous_word_idx = None
            token_confidence = []
            for k, word_idx in enumerate(word_ids):
                try:
                    tok_conf = list(tagset[k])
                except TypeError:
                    # use the object itself if it's not iterable
                    tok_conf = tagset[k]

                if word_idx is not None:
                    # add nonentity tokens if there is a gap in word ids (usually caused by a newline token)
                    if previous_word_idx is not None:
                        diff = word_idx - previous_word_idx
                        for _ in range(diff - 1):
                            tmp = [0] * len(tok_conf)
                            tmp[id_of_non_entity] = 1.0
                            token_confidence.append(tmp)
                    # add confidence value if this is the first token of the word
                    if word_idx != previous_word_idx:
                        token_confidence.append(tok_conf)
                    elif sum_all_tokens:
                        # if sum_all_tokens=True the confidence for all tokens of one word will be summarized
                        token_confidence[-1] = [a + b for a, b in zip(token_confidence[-1], tok_conf)]
                previous_word_idx = word_idx
            confidence.append(token_confidence)
        return confidence