from collections import defaultdict

import streamlit as st
import torch
from transformers import BertForTokenClassification, BertTokenizerFast, CharSpan

# Highlight colour per error type.
COLOURS = {
    "ADJ": "#1DB5D7",
    "ART": "#47A455",
    "ELI": "#FF5252",
    "FIN": "#9873E6",
    "NEG": "#FF914D",
}


@st.cache(allow_output_mutation=True)
def load_model():
    print("Loading classification model")
    return BertForTokenClassification.from_pretrained("classifier/")


@st.cache(allow_output_mutation=True)
def load_tokenizer():
    print("Loading tokenizer for classification model")
    return BertTokenizerFast.from_pretrained("classifier/")


model = load_model()
tokenizer = load_tokenizer()
special_tokens = [
    tokenizer.unk_token,
    tokenizer.sep_token,
    tokenizer.pad_token,
    tokenizer.cls_token,
    tokenizer.mask_token,
]


def predict(logits, threshold=0.3):
    """Take the model logits for a batch of sentences and return a multi-hot
    prediction: every label whose probability exceeds `threshold` is set to 1."""
    probabilities = torch.nn.Softmax(dim=2)(logits)
    predicted = (probabilities > threshold).long()
    # If "O" (label id 0) is among the predicted labels for a token, discard
    # any other error types predicted for that token.
    for i, sent in enumerate(predicted):
        for j, tok in enumerate(sent):
            if tok[0] == 1:
                predicted[i][j] = torch.tensor([1] + [0] * (model.num_labels - 1))
    return predicted


def one_hot_label_decode(tok_labels):
    """Turn a multi-hot vector back into a list of label strings."""
    labels = []
    for i, v in enumerate(tok_labels):
        if v != 0:
            labels.append(model.config.id2label[i])
    return labels


def merge_char_spans(charspans: list):
    """Merge a list of character spans into a single span covering all of them."""
    start = min(span.start for span in charspans)
    end = max(span.end for span in charspans)
    return CharSpan(start, end)


def merge_error_spans(errors):
    """Collapse touching or adjacent spans of the same error type into one span.
    Disjoint spans are kept as separate annotations."""
    merged_spans = defaultdict(list)
    for error_type, charspans in errors.items():
        merged_spans[error_type] = charspans[:1]
        for charspan in charspans[1:]:
            last = merged_spans[error_type][-1]
            if charspan.start <= last.end + 1:
                # Touching or one character apart: extend the previous span.
                merged_spans[error_type][-1] = merge_char_spans([last, charspan])
            else:
                # Disjoint span: keep it as a separate annotation rather than
                # silently dropping it.
                merged_spans[error_type].append(charspan)
    return merged_spans


def predict_error_spans(sentence):
    """Run the token classifier on a sentence and return the sentence together
    with the merged character spans of each predicted error type."""
    encoded_orig = tokenizer(sentence, return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(encoded_orig.input_ids[0])
    with torch.no_grad():
        output = model(**encoded_orig)
    predicted_labels = predict(output.logits, threshold=0.2)
    errors = defaultdict(list)
    for i, (t, p) in enumerate(zip(tokens, predicted_labels[0])):
        labels = one_hot_label_decode(p)
        if t not in special_tokens and labels != ["O"]:
            for label in labels:
                position, e_type = label.split("-")
                if position == "B" or not errors[e_type]:
                    # Start of a new error span (an "I-" tag with no preceding
                    # "B-" tag is treated as a new span rather than crashing).
                    errors[e_type].append(encoded_orig.token_to_chars(i))
                else:
                    # Continuation: extend the last span of this error type.
                    errors[e_type][-1] = merge_char_spans(
                        [errors[e_type][-1], encoded_orig.token_to_chars(i)]
                    )
    return sentence, merge_error_spans(errors)


def annotate(sentence: str):
    """Return the sentence as a list of plain strings and
    (text, label, colour) tuples, ready for rendering."""
    sent_and_error_spans = predict_error_spans(sentence)
    error_spans_ordered = []
    for error_type, charspans in sent_and_error_spans[1].items():
        for charspan in charspans:
            error_spans_ordered.append((charspan, error_type))
    error_spans_ordered = sorted(error_spans_ordered, key=lambda x: x[0])

    annotated_sentence = []
    for i, (error_span, label) in enumerate(error_spans_ordered):
        if i > 0:
            # Plain text between the previous error span and this one.
            annotated_sentence.append(
                sentence[error_spans_ordered[i - 1][0].end:error_span.start]
            )
        else:
            # Plain text before the first error span.
            annotated_sentence.append(sentence[:error_span.start])
        annotated_sentence.append(
            (sentence[error_span.start:error_span.end], label, COLOURS[label])
        )
    if error_spans_ordered:
        # Plain text after the last error span.
        annotated_sentence.append(sentence[error_span.end:])
    else:
        annotated_sentence.append(sentence)
    return annotated_sentence
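

# --- Example driver (a minimal sketch, not part of the original file) ---
# Assumption: annotate() returns plain strings mixed with
# (text, label, colour) tuples, which matches the input format expected by
# annotated_text() from the third-party st-annotated-text package, so the
# page could be wired up as below. The title and widget label are
# illustrative placeholders, not taken from the original source.
from annotated_text import annotated_text

st.title("Error span detection")
sentence = st.text_input("Sentence to analyse")
if sentence:
    annotated_text(*annotate(sentence))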