Spaces:
Runtime error
Runtime error
import json | |
import logging | |
from typing import List, Any | |
import copy | |
import torch | |
from torch.utils.data import Dataset | |
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer | |
from util.process_data import Sample, Entity, EntityType, EntityTypeSet, SampleList, Token, Relation | |
from util.configuration import InferenceConfiguration | |
# Allowed relation directions: maps a head entity label to the list of entity
# labels a relation starting at that head may point to (the tail). Only heads
# listed here are considered during relation extraction, and predicted
# (head, tail) label pairs not listed here are discarded.
valid_relations = dict(
    StatedKeyFigure=["StatedKeyFigure", "Condition", "StatedExpression", "DeclarativeExpression"],
    DeclarativeKeyFigure=["DeclarativeKeyFigure", "Condition", "StatedExpression", "DeclarativeExpression"],
    StatedExpression=["Unit", "Factor", "Range", "Condition"],
    DeclarativeExpression=["DeclarativeExpression", "Unit", "Factor", "Range", "Condition"],
    Condition=["Condition", "StatedExpression", "DeclarativeExpression"],
    Range=["Range"],
)
class TokenClassificationDataset(Dataset):
    """PyTorch dataset pairing tokenizer encodings with per-token labels.

    ``encodings`` maps feature names (e.g. the keys a tokenizer returns) to
    per-sample sequences; ``labels`` holds one label sequence per sample.
    Item ``idx`` is a dict with one tensor per encoding key, followed by a
    ``'labels'`` tensor.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        sample = dict()
        for name, column in self.encodings.items():
            sample[name] = torch.tensor(column[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

    def __len__(self):
        # The number of label sequences defines the dataset size.
        return len(self.labels)
class TransformersInference():
    """Entity and (optional) relation extraction with a token-classification model.

    One model, loaded from ``config.model_path_keyfigure``, serves every task:
    the task is selected by a trigger token prepended to each input sequence
    (``[GRP-xx]`` for entity group ``xx``, ``[REL]`` for relations). Results are
    written back onto the ``Sample`` objects of the ``SampleList`` passed to
    ``run_inference``.
    """

    def __init__(self, config: InferenceConfiguration):
        """Load the label set, tokenizer and model described by ``config``.

        The label set is read from ``classification.json`` (relative to the
        current working directory). Trigger and marker tokens are registered
        as additional special tokens on the tokenizer.
        """
        super().__init__()
        self.__logger = logging.getLogger(self.__class__.__name__)
        self.__logger.info(f"Load Configuration: {config.dict()}")
        with open("classification.json", mode='r', encoding="utf-8") as f:
            self.__entity_type_set = EntityTypeSet.parse_obj(json.load(f))
        self.__entity_type_label_to_id_mapping = {x.label: x.idx for x in self.__entity_type_set.all_types()}
        self.__entity_type_id_to_label_mapping = {x.idx: x.label for x in self.__entity_type_set.all_types()}
        self.__logger.info("Load Model: " + config.model_path_keyfigure)
        self.__tokenizer = AutoTokenizer.from_pretrained(config.transformer_model,
                                                         padding="max_length", max_length=512, truncation=True)
        self.__model = AutoModelForTokenClassification.from_pretrained(
            config.model_path_keyfigure, num_labels=len(self.__entity_type_set))
        self.__trainer = Trainer(model=self.__model)
        self.__merge_entities = config.merge_entities
        self.__split_len = config.split_len
        self.__extract_relations = config.extract_relations
        # Register task triggers and entity/relation markers as special tokens
        # so the tokenizer keeps each of them as a single unsplit token.
        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        special_tokens = {"[REL]", "[SUB]", "[/SUB]", "[OBJ]", "[/OBJ]"}
        for grp_idx, grp in enumerate(self.__entity_type_set.groups):
            special_tokens.add(f"[GRP-{grp_idx:02d}]")
            for ent in grp:
                if ent != id_of_non_entity:
                    special_tokens.add(f"[ENT-{ent:02d}]")
                    special_tokens.add(f"[/ENT-{ent:02d}]")
        special_tokens_dict = {'additional_special_tokens': sorted(special_tokens)}
        num_added_toks = self.__tokenizer.add_special_tokens(special_tokens_dict)
        self.__logger.info(f"Added {num_added_toks} new special tokens. All special tokens: '{self.__tokenizer.all_special_tokens}'")
        self.__logger.info("Initialization completed.")

    def run_inference(self, sample_list: SampleList):
        """Predict entities, and relations if enabled, for every sample in place.

        For each entity group the model is run with that group's trigger;
        ``generate_response_entities`` appends the entities to
        ``sample.entities`` and updates ``sample.tags``. With relation
        extraction enabled, ``sample.relations`` is then set for every sample.
        """
        group_predictions = []
        group_entity_ids = []
        self.__logger.info("Predict Entities ...")
        # The token texts are identical for every group (and __get_predictions
        # does not modify them), so build them once.
        token_lists = [[x.text for x in sample.tokens] for sample in sample_list.samples]
        for grp_idx, _ in enumerate(self.__entity_type_set.groups):
            predictions = self.__get_predictions(token_lists, f"[GRP-{grp_idx:02d}]")
            group_entity_ids_ = []
            for sample, prediction_per_tokens in zip(sample_list.samples, predictions):
                group_entity_ids_.append(self.generate_response_entities(sample, prediction_per_tokens, grp_idx))
            group_predictions.append(predictions)
            group_entity_ids.append(group_entity_ids_)
        if self.__extract_relations:
            self.__logger.info("Predict Relations ...")
            self.__do_extract_relations(sample_list, group_predictions, group_entity_ids)

    def __do_extract_relations(self, sample_list, group_predictions, group_entity_ids):
        """Predict relations for every sample and store them on ``sample.relations``.

        For each entity whose label is a key of ``valid_relations`` (a possible
        relation head), one marked-up copy of the sample's tokens is built.
        The model is then run on all copies with the ``[REL]`` trigger.
        ``group_predictions[g][s][t]`` / ``group_entity_ids[g][s][t]`` are the
        label id / entity id of token ``t`` of sample ``s`` in group ``g``.
        """
        for sample_idx, sample in enumerate(sample_list.samples):
            head_entities = [entity_ for entity_ in sample.entities
                             if entity_.ent_type.label in valid_relations]
            masked_tokens = []
            masked_tokens_align = []
            for entity_ in head_entities:
                ent_masked_tokens, ent_masked_tokens_align = self.__mask_tokens_for_head(
                    sample_idx, sample, entity_.id, group_predictions, group_entity_ids)
                masked_tokens.append(ent_masked_tokens)
                masked_tokens_align.append(ent_masked_tokens_align)
            rel_predictions = self.__get_predictions(masked_tokens, "[REL]")
            self.generate_response_relations(sample, head_entities, masked_tokens_align, rel_predictions)

    def __mask_tokens_for_head(self, sample_idx, sample, head_id, group_predictions, group_entity_ids):
        """Wrap every predicted entity span of one sample in marker tokens.

        The span carrying entity id ``head_id`` becomes
        ``[SUB] [ENT-xx] ... [/ENT-xx] [/SUB]``; every other span becomes
        ``[OBJ] [ENT-xx] ... [/ENT-xx] [/OBJ]``. Per group, a span starts or
        ends wherever the predicted label changes.

        Returns ``(tokens, align)`` of equal length. ``align`` holds the
        span's entity id (as ``str``) at every marker position and ``None`` at
        the positions of the original words.
        """
        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        tokens = []
        align = []
        last_preds = [id_of_non_entity] * len(group_predictions)
        last_ent_ids = [-1] * len(group_entity_ids)
        for token_idx, token in enumerate(sample.tokens):
            preds = [group[sample_idx][token_idx] for group in group_predictions]
            ent_ids = [group[sample_idx][token_idx] for group in group_entity_ids]
            # First close every span that ended before this token ...
            for pred, last_pred, last_ent_id in zip(preds, last_preds, last_ent_ids):
                if last_pred != pred and last_pred != id_of_non_entity:
                    mask = "[/SUB]" if last_ent_id == head_id else "[/OBJ]"
                    tokens.extend([f"[/ENT-{last_pred:02d}]", mask])
                    align.extend([str(last_ent_id), str(last_ent_id)])
            # ... then open every span that starts at this token.
            for pred, ent_id, last_pred in zip(preds, ent_ids, last_preds):
                if last_pred != pred and pred != id_of_non_entity:
                    mask = "[SUB]" if ent_id == head_id else "[OBJ]"
                    tokens.extend([mask, f"[ENT-{pred:02d}]"])
                    align.extend([str(ent_id), str(ent_id)])
            tokens.append(token.text)
            # Word positions carry no entity id. Storing the word text here would
            # let numeric words (e.g. "2005") be parsed as entity ids later.
            align.append(None)
            last_preds = preds
            last_ent_ids = ent_ids
        # Close the spans still open at the end of the sample.
        for last_pred, last_ent_id in zip(last_preds, last_ent_ids):
            if last_pred != id_of_non_entity:
                mask = "[/SUB]" if last_ent_id == head_id else "[/OBJ]"
                tokens.extend([f"[/ENT-{last_pred:02d}]", mask])
                align.extend([str(last_ent_id), str(last_ent_id)])
        return tokens, align

    def generate_response_entities(self, sample: Sample, predictions_per_tokens: List[int], grp_idx: int):
        """Create entities for one sample from one group's per-token predictions.

        Every token whose prediction is not the non-entity id becomes an
        ``Entity`` with id ``grp_idx * 1000 + n`` (``n`` counting from 1). With
        ``merge_entities`` enabled, adjacent entities of the same type are
        merged via ``__do_merge_entities``. The entities are appended to
        ``sample.entities``, and ``sample.tags`` is updated for tokens they
        cover.

        Returns:
            One entity id per token (-1 for non-entity tokens).
        """
        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        entities = []
        entity_ids = []
        idx = grp_idx * 1000
        for token, prediction in zip(sample.tokens, predictions_per_tokens):
            if prediction == id_of_non_entity:
                entity_ids.append(-1)
                continue
            idx += 1
            entities.append(self.__build_entity(idx, prediction, token))
            entity_ids.append(idx)
        if self.__merge_entities:
            entities = self.__do_merge_entities(copy.deepcopy(entities))
            # A merged entity keeps the id of its first token, so propagate that
            # id over every run of equal predictions.
            # NOTE(review): __do_merge_entities also requires a character gap of
            # at most 1 between tokens; runs separated by a wider gap are split
            # there but share one id here - confirm this is intended.
            for pos in range(1, len(entity_ids)):
                if predictions_per_tokens[pos] == predictions_per_tokens[pos - 1]:
                    entity_ids[pos] = entity_ids[pos - 1]
        sample.entities += entities
        tags = sample.tags if len(sample.tags) > 0 else [id_of_non_entity] * len(sample.tokens)
        for tag_id, tok in enumerate(sample.tokens):
            for ent in entities:
                if ent.start <= tok.start < ent.end:
                    tags[tag_id] = ent.ent_type.idx
        self.__logger.info(tags)
        sample.tags = tags
        return entity_ids

    def generate_response_relations(self, sample: Sample, head_entities: List[Entity],
                                    masked_tokens_align: List[List[Any]], rel_predictions: List[List[int]]):
        """Turn relation predictions into ``Relation`` objects on ``sample.relations``.

        ``masked_tokens_align[i]`` and ``rel_predictions[i]`` belong to
        ``head_entities[i]``. A position yields a candidate relation
        head -> tail when its prediction is not the non-entity id and its align
        entry is an entity id. Candidates are kept only if
        ``__validate_relation`` accepts them. Any previous ``sample.relations``
        is replaced.

        NOTE(review): several marker positions carry the same tail id, so the
        same (head, tail, type) triple may be appended more than once - confirm
        whether deduplication is wanted.
        """
        relations = []
        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        idx = 0
        for entity_, align_per_ent, prediction_per_ent in zip(head_entities, masked_tokens_align, rel_predictions):
            for token, prediction in zip(align_per_ent, prediction_per_ent):
                if prediction == id_of_non_entity or token is None:
                    continue
                try:
                    tail = int(token)
                except (TypeError, ValueError):
                    # Not an entity id (e.g. plain word text supplied by a caller).
                    continue
                if not self.__validate_relation(sample.entities, entity_.id, tail, prediction):
                    continue
                idx += 1
                relations.append(self.__build_relation(idx, entity_.id, tail, prediction))
        sample.relations = relations

    def __validate_relation(self, entities: List[Entity], head: int, tail: int, prediction: int):
        """Return True if a relation from entity ``head`` to entity ``tail`` is allowed.

        Both ids must differ and belong to entities in ``entities`` (the first
        match decides the type). The tail's label must be listed in
        ``valid_relations`` for the head's label. ``prediction`` is not used.
        """
        if head == tail:
            return False
        missing = object()
        head_label = next((ent.ent_type.label for ent in entities if ent.id == head), missing)
        if head_label is missing:
            return False
        tail_label = next((ent.ent_type.label for ent in entities if ent.id == tail), missing)
        if tail_label is missing:
            return False
        # .get: a head label without an entry allows no relations instead of raising KeyError.
        return tail_label in valid_relations.get(head_label, ())

    def __build_entity(self, idx: int, prediction: int, token: Token) -> Entity:
        """Build a single-token ``Entity`` of type ``prediction``."""
        return Entity(
            id=idx,
            text=token.text,
            start=token.start,
            end=token.end,
            ent_type=EntityType(
                idx=prediction,
                label=self.__entity_type_id_to_label_mapping[prediction]
            )
        )

    def __build_relation(self, idx: int, head: int, tail: int, prediction: int) -> Relation:
        """Build a ``Relation``; its type label comes from the same id -> label mapping as entities."""
        return Relation(
            id=idx,
            head=head,
            tail=tail,
            rel_type=EntityType(
                idx=prediction,
                label=self.__entity_type_id_to_label_mapping[prediction]
            )
        )

    def __do_merge_entities(self, input_ents_):
        """Merge consecutive entities of the same type separated by at most one character.

        The merged text joins the parts with a single space when the gap is
        exactly one character. The passed entities are modified in place
        (the caller passes a deep copy).
        """
        out_ents = []
        current_ent = None
        for ent in input_ents_:
            if current_ent is None:
                current_ent = ent
                continue
            idx_diff = ent.start - current_ent.end
            if ent.ent_type.idx == current_ent.ent_type.idx and idx_diff <= 1:
                current_ent.end = ent.end
                current_ent.text += (" " if idx_diff == 1 else "") + ent.text
            else:
                out_ents.append(current_ent)
                current_ent = ent
        if current_ent is not None:
            out_ents.append(current_ent)
        return out_ents

    def __get_predictions(self, token_lists: List[List[str]], trigger: str) -> List[List[int]]:
        """Get per-token label ids of the Transformer sequence labeling model.

        Every token list is prefixed with ``trigger``. If ``split_len`` > 0,
        long lists are split into chunks (``__do_split_sentences``) and the
        chunk predictions are concatenated again. For the ``[REL]`` trigger, a
        chunk without a ``[SUB]`` marker gets the non-entity id for all tokens.

        Returns:
            One list per input token list, with exactly one label id per token.
        """
        if not token_lists:
            # Nothing to predict; avoids running the tokenizer/Trainer on an empty batch.
            return []
        if self.__split_len <= 0:
            return self.__predict_batch(token_lists, trigger)
        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        predictions = []
        for sample_chunks in self.__do_split_sentences(token_lists, self.__split_len):
            chunk_predictions = self.__predict_batch(sample_chunks, trigger)
            predictions_part = []
            for chunk, pred in zip(sample_chunks, chunk_predictions):
                if trigger == "[REL]" and "[SUB]" not in chunk:
                    predictions_part += [id_of_non_entity] * len(chunk)
                else:
                    predictions_part += pred
            predictions.append(predictions_part)
        return predictions

    def __predict_batch(self, token_lists: List[List[str]], trigger: str) -> List[List[int]]:
        """Run the model once on ``token_lists``, each prefixed with ``trigger``.

        Returns the arg-max label id per word with the trigger's prediction
        dropped. The result is padded with the non-entity id (or trimmed) to
        exactly ``len(tokens)`` entries. Without this, words the tokenizer
        yields no sub-tokens for at the end of a sequence, or words cut off by
        truncation, would shorten the result and break positional indexing in
        callers.
        """
        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        token_lists_trigger = [[trigger] + tokens for tokens in token_lists]
        encodings = self.__tokenizer(token_lists_trigger, is_split_into_words=True,
                                     padding='max_length', truncation=True)
        # The Dataset expects labels; their values are irrelevant for prediction.
        dummy_labels = [[0] * len(encodings.word_ids(batch_index=i)) for i in range(len(token_lists_trigger))]
        dataset = TokenClassificationDataset(encodings, dummy_labels)
        predictions_raw, _, _ = self.__trainer.predict(dataset)
        aligned = self.__align_predictions(predictions_raw, encodings)
        predictions = []
        for tokens, word_scores in zip(token_lists, aligned):
            labels = [scores.index(max(scores)) for scores in word_scores][1:len(tokens) + 1]
            labels += [id_of_non_entity] * (len(tokens) - len(labels))
            predictions.append(labels)
        return predictions

    def __do_split_sentences(self, tokens_: List[List[str]], split_len_=200) -> List[List[List[str]]]:
        """Split every token list longer than ``split_len_`` into consecutive chunks.

        Lists of at most ``split_len_`` tokens are returned unchanged as a
        single chunk. Longer lists are cut into chunks of
        ``len // ceil(len / split_len_) + 1`` tokens (the last one may be
        shorter). A token containing a newline at a chunk's first or last
        position is replaced by ``"."`` in the chunk. The input lists are not
        modified.
        """
        res_tokens = []
        for tok_lst in tokens_:
            length = len(tok_lst)
            if length <= split_len_:
                res_tokens.append([tok_lst])
                continue
            num_lists = -(-length // split_len_)  # ceil(length / split_len_)
            new_length = length // num_lists + 1
            chunks = []
            # Step until the list is exhausted: looping a fixed num_lists times
            # can run past the end (e.g. length 8, split_len_ 2) and IndexError.
            for start_idx in range(0, length, new_length):
                chunk = tok_lst[start_idx:start_idx + new_length]  # a copy; the caller's list is untouched
                # Presumably because a dropped newline at a chunk edge would not be
                # gap-filled by __align_predictions (it only fills between words).
                if "\n" in chunk[0]:
                    chunk[0] = "."
                if "\n" in chunk[-1]:
                    chunk[-1] = "."
                chunks.append(chunk)
            self.__logger.info(f"Splitting a list of {length} elements into {len(chunks)} lists of length {new_length}..")
            res_tokens.append(chunks)
        return res_tokens

    def __align_predictions(self, predictions, tokenized_inputs, sum_all_tokens=False) -> List[List[List[float]]]:
        """Align predicted sub-token scores to one score vector per input word.

        Per sequence, the scores of each word's first sub-token are kept; with
        ``sum_all_tokens`` the scores of all its sub-tokens are summed. Tokens
        without a word id (special/padding tokens) are skipped. Words missing
        between two present words get a vector with 1.0 at the non-entity id
        and 0 elsewhere. Words missing at the end of a sequence are not added
        here.
        """
        confidence = []
        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        for i, tagset in enumerate(predictions):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            previous_word_idx = None
            token_confidence = []
            for k, word_idx in enumerate(word_ids):
                try:
                    tok_conf = [value for value in tagset[k]]
                except TypeError:
                    # use the object itself if it's not iterable
                    tok_conf = tagset[k]
                if word_idx is not None:
                    # add non-entity entries if there is a gap in word ids (usually caused by a newline token)
                    if previous_word_idx is not None:
                        for _ in range(word_idx - previous_word_idx - 1):
                            filler = [0 for _ in tok_conf]
                            filler[id_of_non_entity] = 1.0
                            token_confidence.append(filler)
                    if word_idx != previous_word_idx:
                        # first sub-token of a word
                        token_confidence.append(tok_conf)
                    elif sum_all_tokens:
                        token_confidence[-1] = [a + b for a, b in zip(token_confidence[-1], tok_conf)]
                previous_word_idx = word_idx
            confidence.append(token_confidence)
        return confidence