from collections import defaultdict

import streamlit as st
import torch
from transformers import BertForTokenClassification, BertTokenizerFast, CharSpan

# highlight colour for each error type in the annotated output
COLOURS = {
    "ADJ": "#1DB5D7",
    "ART": "#47A455",
    "ELI": "#FF5252",
    "FIN": "#9873E6",
    "NEG": "#FF914D",
}

@st.cache_resource  # cache across Streamlit reruns so the weights are only loaded once
def load_model():
    print("Loading classification model")
    return BertForTokenClassification.from_pretrained("alice-hml/mBERT_grammatical_error_tagger")

@st.cache_resource  # likewise for the matching tokenizer
def load_tokenizer():
    print("Loading tokenizer for classification model")
    return BertTokenizerFast.from_pretrained("alice-hml/mBERT_grammatical_error_tagger")

model = load_model()
tokenizer = load_tokenizer()
# tokens that should never be annotated ([CLS], [SEP], padding, etc.)
special_tokens = [tokenizer.unk_token, tokenizer.sep_token, tokenizer.pad_token, tokenizer.cls_token, tokenizer.mask_token]

def predict(logits, threshold=0.3):
    """Turn the model logits for a batch into one-hot label predictions.

    A label is predicted for a token whenever its softmax probability
    exceeds `threshold`, so a single token can carry several error types.
    """
    probabilities = torch.nn.Softmax(dim=2)(logits)
    predicted = (probabilities > threshold).long()
    # if O (no error, label id 0) is among the predicted labels,
    # discard any other error types predicted for that token
    for i, sent in enumerate(predicted):
        for j, tok in enumerate(sent):
            if tok[0] == 1:
                predicted[i][j] = torch.tensor([1] + [0] * (model.num_labels - 1))
    return predicted

def one_hot_label_decode(tok_labels):
    """Map a one-hot vector back to the list of label names it encodes."""
    labels = []
    for i, v in enumerate(tok_labels):
        if v != 0:
            labels.append(model.config.id2label[i])
    return labels

def merge_char_spans(charspans: list) -> CharSpan:
    """Collapse a list of character spans into one span covering them all."""
    start = min(span.start for span in charspans)
    end = max(span.end for span in charspans)
    return CharSpan(start, end)

def merge_error_spans(errors):
    """Merge spans of the same error type that touch or are one character apart."""
    merged_spans = defaultdict(list)
    for error_type, charspans in errors.items():
        merged_spans[error_type] = charspans[:1]
        for charspan in charspans[1:]:
            last = merged_spans[error_type][-1]
            if charspan.start in (last.end, last.end + 1):
                # contiguous (or separated by a single space): extend the last span
                merged_spans[error_type][-1] = merge_char_spans([last, charspan])
            else:
                # disjoint span: keep it separate instead of silently dropping it
                merged_spans[error_type].append(charspan)
    return merged_spans
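
# For illustration (hypothetical spans): two ART spans separated by a single
# character merge into one, while a distant span stays separate:
#   merge_error_spans({"ART": [CharSpan(0, 2), CharSpan(3, 7), CharSpan(20, 24)]})
#   -> {"ART": [CharSpan(0, 7), CharSpan(20, 24)]}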

def predict_error_spans(sentence):
    """Run the tagger on a sentence and collect the character spans of each error type."""
    encoded_orig = tokenizer(sentence, return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(encoded_orig.input_ids[0])
    with torch.no_grad():  # inference only, no gradients needed
        output = model(**encoded_orig)
    predicted_labels = predict(output.logits, threshold=0.2)
    errors = defaultdict(list)
    for i, (t, p) in enumerate(zip(tokens, predicted_labels[0])):
        labels = one_hot_label_decode(p)
        if t not in special_tokens and labels != ["O"]:
            for error_type in labels:
                position, e_type = error_type.split("-")
                if position == "B" or not errors[e_type]:
                    # beginning of an error span (a stray I- tag also opens a new span)
                    errors[e_type].append(encoded_orig.token_to_chars(i))
                else:
                    # inside an error span: extend the last span of this type
                    errors[e_type][-1] = merge_char_spans([errors[e_type][-1], encoded_orig.token_to_chars(i)])
    return sentence, merge_error_spans(errors)

def annotate(sentence: str):
    """Split the sentence into plain segments and (text, label, colour) error segments."""
    sentence, error_spans = predict_error_spans(sentence)
    # flatten {error_type: [spans]} into a list of (span, error_type) sorted by position
    error_spans_ordered = sorted(
        ((charspan, error_type)
         for error_type, charspans in error_spans.items()
         for charspan in charspans),
        key=lambda x: x[0],
    )
    annotated_sentence = []
    for i, (error_span, label) in enumerate(error_spans_ordered):
        # plain text between the previous error span (or the sentence start) and this one
        prev_end = error_spans_ordered[i - 1][0].end if i > 0 else 0
        annotated_sentence.append(sentence[prev_end:error_span.start])
        annotated_sentence.append((sentence[error_span.start:error_span.end], label, COLOURS[label]))
    # trailing plain text after the last error span, or the whole sentence if error-free
    if error_spans_ordered:
        annotated_sentence.append(sentence[error_spans_ordered[-1][0].end:])
    else:
        annotated_sentence.append(sentence)
    return annotated_sentence
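
# The Space's UI code is not included above even though streamlit is imported.
# Below is a minimal sketch of how annotate() could be wired to a page, assuming
# the st-annotated-text package (pip install st-annotated-text): its
# annotated_text() helper accepts exactly the mixed str / (text, label, colour)
# list that annotate() returns. The title and prompt strings are invented.
from annotated_text import annotated_text

st.title("Grammatical error tagging")
sentence = st.text_input("Sentence to check:")
if sentence:
    annotated_text(*annotate(sentence))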