from collections import defaultdict

import streamlit as st
import torch
from transformers import BertForTokenClassification, BertTokenizerFast, CharSpan
# Background colour used to highlight each error type in the annotated output
COLOURS = {
    "ADJ": "#1DB5D7",
    "ART": "#47A455",
    "ELI": "#FF5252",
    "FIN": "#9873E6",
    "NEG": "#FF914D",
}

def load_model():
    print("Loading classification model")
    return BertForTokenClassification.from_pretrained("alice-hml/mBERT_grammatical_error_tagger")

def load_tokenizer():
    print("Loading tokenizer for classification model")
    return BertTokenizerFast.from_pretrained("alice-hml/mBERT_grammatical_error_tagger")
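
# NOTE (assumption): in a Streamlit app these loaders are usually cached so the
# model and tokenizer are loaded once per process rather than on every script
# rerun, e.g. by decorating each function with @st.cache_resource.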
model = load_model()
tokenizer = load_tokenizer()

# Tokens that never carry an error label and are skipped during decoding
special_tokens = [tokenizer.unk_token, tokenizer.sep_token, tokenizer.pad_token, tokenizer.cls_token, tokenizer.mask_token]

def predict(logits, threshold=0.3):
    """Turn model logits for a batch of sentences into multi-hot label predictions.

    Every label whose softmax probability exceeds `threshold` is kept, so a
    single token can be tagged with several error types at once.
    """
    probabilities = torch.nn.Softmax(dim=2)(logits)
    predicted = (probabilities > threshold).long()
    # if "O" (label id 0, no error) is among the predicted labels,
    # discard the other predicted error types for that token
    for i, sent in enumerate(predicted):
        for j, tok in enumerate(sent):
            if tok[0] == 1:
                predicted[i][j] = torch.tensor([1] + [0] * (model.num_labels - 1))
    return predicted
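
# Illustrative shapes: logits of shape (1, num_tokens, model.num_labels) come
# back as a multi-hot tensor of the same shape, e.g.
#     predict(torch.randn(1, 4, model.num_labels)).shape
#     # -> torch.Size([1, 4, model.num_labels])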

def one_hot_label_decode(tok_labels):
    """Map a multi-hot vector for one token back to its label names."""
    labels = []
    for i, v in enumerate(tok_labels):
        if v != 0:
            labels.append(model.config.id2label[i])
    return labels

def merge_char_spans(charspans: list):
    """Collapse a list of CharSpans into a single span covering all of them."""
    start = min(span.start for span in charspans)
    end = max(span.end for span in charspans)
    return CharSpan(start, end)
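
# e.g. (hypothetical offsets) merge_char_spans([CharSpan(12, 15), CharSpan(16, 20)])
# returns CharSpan(12, 20)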

def merge_error_spans(errors):
    """Merge spans of the same error type that touch or are one character apart."""
    merged_spans = defaultdict(list)
    for error_type in errors:
        merged_spans[error_type] = errors[error_type][:1]
        for charspan in errors[error_type][1:]:
            last = merged_spans[error_type][-1]
            # spans separated by at most one character (e.g. a space) are merged
            if charspan.start in (last.end, last.end + 1):
                merged_spans[error_type][-1] = merge_char_spans([last, charspan])
            else:
                # non-contiguous spans are kept as separate annotations
                merged_spans[error_type].append(charspan)
    return merged_spans
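
# e.g. (hypothetical offsets)
#     {"ART": [CharSpan(0, 3), CharSpan(4, 7), CharSpan(20, 24)]}
# becomes
#     {"ART": [CharSpan(0, 7), CharSpan(20, 24)]}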

def predict_error_spans(sentence):
    """Run the tagger on a sentence; return (sentence, {error type: merged char spans})."""
    encoded_orig = tokenizer(sentence, return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(encoded_orig.input_ids[0])
    with torch.no_grad():
        output = model(**encoded_orig)
    predicted_labels = predict(output.logits, threshold=0.2)
    errors = defaultdict(list)
    for i, (t, p) in enumerate(zip(tokens, predicted_labels[0])):
        labels = one_hot_label_decode(p)
        if t not in special_tokens and labels != ["O"]:
            for error_type in labels:
                position, e_type = error_type.split("-")
                # "B-" opens a new span; "I-" extends the current one
                # (an "I-" with no open span also starts a new one)
                if position == "B" or not errors[e_type]:
                    errors[e_type].append(encoded_orig.token_to_chars(i))
                else:
                    errors[e_type][-1] = merge_char_spans([errors[e_type][-1], encoded_orig.token_to_chars(i)])
    return sentence, merge_error_spans(errors)
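
# e.g. (hypothetical output) predict_error_spans("He go to school") might return
#     ("He go to school", {"FIN": [CharSpan(3, 5)]})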

def annotate(sentence: str):
    """Split a sentence into plain substrings and (substring, label, colour) triples."""
    sentence, error_spans = predict_error_spans(sentence)
    # flatten {error type: spans} into a list of (span, error type), ordered by position
    error_spans_ordered = []
    for error_type, charspans in error_spans.items():
        for charspan in charspans:
            error_spans_ordered.append((charspan, error_type))
    error_spans_ordered = sorted(error_spans_ordered, key=lambda x: x[0])
    annotated_sentence = []
    for i, (error_span, label) in enumerate(error_spans_ordered):
        # plain text between the previous span (or the sentence start) and this span
        if i > 0:
            annotated_sentence.append(sentence[error_spans_ordered[i - 1][0].end:error_span.start])
        else:
            annotated_sentence.append(sentence[:error_span.start])
        annotated_sentence.append((sentence[error_span.start:error_span.end], label, COLOURS[label]))
    if len(error_spans_ordered) > 0:
        # remaining plain text after the last error span
        annotated_sentence.append(sentence[error_span.end:])
    else:
        annotated_sentence.append(sentence)
    return annotated_sentence
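
# The UI wiring below is a hedged sketch: the original interface is not part of
# this file, so the title and widget labels are assumptions. The
# (text, label, colour) triples produced by annotate() match the input format of
# the third-party st-annotated-text component (pip install st-annotated-text).
from annotated_text import annotated_text

st.title("Grammatical error tagger")
user_input = st.text_input("Enter a sentence to analyse:")
if user_input:
    annotated_text(*annotate(user_input))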